fix callbacks

This commit is contained in:
FilippoOlivo
2025-11-13 10:48:20 +01:00
parent 09677d3c15
commit c0cbb13a92
2 changed files with 20 additions and 25 deletions

View File

@@ -5,7 +5,6 @@ from lightning.pytorch import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency, is_function
from ..condition import InputTargetCondition
from ..data.dataset import PinaGraphDataset
class NormalizerDataCallback(Callback):
@@ -122,7 +121,10 @@ class NormalizerDataCallback(Callback):
"""
# Ensure datsets are not graph-based
if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset):
if any(
ds.is_graph_dataset
for ds in trainer.datamodule.train_dataset.values()
):
raise NotImplementedError(
"NormalizerDataCallback is not compatible with "
"graph-based datasets."
@@ -164,8 +166,8 @@ class NormalizerDataCallback(Callback):
:param dataset: The `~pina.data.dataset.PinaDataset` dataset.
"""
for cond in conditions:
if cond in dataset.conditions_dict:
data = dataset.conditions_dict[cond][self.apply_to]
if cond in dataset:
data = dataset[cond].data[self.apply_to]
shift = self.shift_fn(data)
scale = self.scale_fn(data)
self._normalizer[cond] = {
@@ -197,25 +199,20 @@ class NormalizerDataCallback(Callback):
:param PinaDataset dataset: The dataset to be normalized.
"""
# Initialize update dictionary
update_dataset_dict = {}
# Iterate over conditions and apply normalization
for cond, norm_params in self.normalizer.items():
points = dataset.conditions_dict[cond][self.apply_to]
update_dataset_dict = {}
points = dataset[cond].data[self.apply_to]
scale = norm_params["scale"]
shift = norm_params["shift"]
normalized_points = self._norm_fn(points, scale, shift)
update_dataset_dict[cond] = {
self.apply_to: (
LabelTensor(normalized_points, points.labels)
if isinstance(points, LabelTensor)
else normalized_points
)
}
# Update the dataset in-place
dataset.update_data(update_dataset_dict)
update_dataset_dict[self.apply_to] = (
LabelTensor(normalized_points, points.labels)
if isinstance(points, LabelTensor)
else normalized_points
)
dataset[cond].data.update(update_dataset_dict)
@property
def normalizer(self):