fix callbacks

FilippoOlivo
2025-11-13 10:48:20 +01:00
parent 09677d3c15
commit c0cbb13a92
2 changed files with 20 additions and 25 deletions

View File

@@ -5,7 +5,6 @@ from lightning.pytorch import Callback
 from ..label_tensor import LabelTensor
 from ..utils import check_consistency, is_function
 from ..condition import InputTargetCondition
-from ..data.dataset import PinaGraphDataset


 class NormalizerDataCallback(Callback):
@@ -122,7 +121,10 @@ class NormalizerDataCallback(Callback):
         """
         # Ensure datasets are not graph-based
-        if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset):
+        if any(
+            ds.is_graph_dataset
+            for ds in trainer.datamodule.train_dataset.values()
+        ):
             raise NotImplementedError(
                 "NormalizerDataCallback is not compatible with "
                 "graph-based datasets."
             )
@@ -164,8 +166,8 @@ class NormalizerDataCallback(Callback):
         :param dataset: The `~pina.data.dataset.PinaDataset` dataset.
         """
         for cond in conditions:
-            if cond in dataset.conditions_dict:
-                data = dataset.conditions_dict[cond][self.apply_to]
+            if cond in dataset:
+                data = dataset[cond].data[self.apply_to]
                 shift = self.shift_fn(data)
                 scale = self.scale_fn(data)
                 self._normalizer[cond] = {
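
Note: fitting now membership-tests the dataset directly (cond in dataset) and reads tensors through dataset[cond].data instead of a conditions_dict attribute. A rough sketch of the fit loop on plain tensors, assuming mean and standard deviation as the shift/scale functions (the real callback takes configurable shift_fn/scale_fn):

import torch

# Hypothetical per-condition data, laid out the way the callback reads it.
conditions_data = {"data": {"input": torch.randn(100, 2)}}
apply_to = "input"

normalizer = {}
for cond, entry in conditions_data.items():
    data = entry[apply_to]
    # One shift/scale pair per condition, as in _normalizer[cond] above.
    normalizer[cond] = {"shift": data.mean(dim=0), "scale": data.std(dim=0)}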
@@ -197,25 +199,20 @@ class NormalizerDataCallback(Callback):
         :param PinaDataset dataset: The dataset to be normalized.
         """
-        # Initialize update dictionary
-        update_dataset_dict = {}
         # Iterate over conditions and apply normalization
         for cond, norm_params in self.normalizer.items():
-            points = dataset.conditions_dict[cond][self.apply_to]
+            update_dataset_dict = {}
+            points = dataset[cond].data[self.apply_to]
             scale = norm_params["scale"]
             shift = norm_params["shift"]
             normalized_points = self._norm_fn(points, scale, shift)
-            update_dataset_dict[cond] = {
-                self.apply_to: (
-                    LabelTensor(normalized_points, points.labels)
-                    if isinstance(points, LabelTensor)
-                    else normalized_points
-                )
-            }
-        # Update the dataset in-place
-        dataset.update_data(update_dataset_dict)
+            update_dataset_dict[self.apply_to] = (
+                LabelTensor(normalized_points, points.labels)
+                if isinstance(points, LabelTensor)
+                else normalized_points
+            )
+            dataset[cond].data.update(update_dataset_dict)

     @property
     def normalizer(self):
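
Note: the apply step now builds one small update dict per condition, keyed by the normalized field rather than by condition name, and merges it into that condition's data mapping inside the loop; the old code built a single dict keyed by condition and pushed it once after the loop. A sketch of the new per-condition pattern with plain dicts and tensors (identifiers here are illustrative, not PINA's API):

import torch

apply_to = "input"
dataset = {"data": {"input": torch.randn(8, 2)}}
normalizer = {"data": {"shift": torch.zeros(2), "scale": torch.ones(2)}}

for cond, norm_params in normalizer.items():
    update = {}
    points = dataset[cond][apply_to]
    # Plain shift/scale normalization; the callback delegates this to _norm_fn.
    update[apply_to] = (points - norm_params["shift"]) / norm_params["scale"]
    # Merge into the condition's data mapping in place.
    dataset[cond].update(update)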

View File

@@ -133,13 +133,12 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
         :param PINNInterface solver: The solver object.
         """
-        new_points = {}
         for name in self._condition_to_update:
-            current_points = self.dataset.conditions_dict[name]["input"]
-            new_points[name] = {
-                "input": self.sample(current_points, name, solver)
-            }
-        self.dataset.update_data(new_points)
+            new_points = {}
+            current_points = self.dataset[name].data["input"]
+            new_points["input"] = self.sample(current_points, name, solver)
+            self.dataset[name].update_data(new_points)

     def _compute_population_size(self, conditions):
         """
@@ -150,6 +149,5 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
         :rtype: dict
         """
         return {
-            cond: len(self.dataset.conditions_dict[cond]["input"])
-            for cond in conditions
+            cond: len(self.dataset[cond].data["input"]) for cond in conditions
         }
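
Note: the population-size helper collapses to a one-line comprehension over the same dataset[cond].data["input"] access path. With plain dicts standing in for the dataset, it behaves like:

dataset = {"interior": {"input": list(range(10))}, "boundary": {"input": [0, 1]}}
sizes = {cond: len(dataset[cond]["input"]) for cond in ["interior", "boundary"]}
print(sizes)  # {'interior': 10, 'boundary': 2}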