fix callbacks

This commit is contained in:
FilippoOlivo
2025-11-13 10:48:20 +01:00
parent 09677d3c15
commit c0cbb13a92
2 changed files with 20 additions and 25 deletions

View File

@@ -5,7 +5,6 @@ from lightning.pytorch import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency, is_function
from ..condition import InputTargetCondition
from ..data.dataset import PinaGraphDataset
class NormalizerDataCallback(Callback):
@@ -122,7 +121,10 @@ class NormalizerDataCallback(Callback):
"""
# Ensure datasets are not graph-based
if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset):
if any(
ds.is_graph_dataset
for ds in trainer.datamodule.train_dataset.values()
):
raise NotImplementedError(
"NormalizerDataCallback is not compatible with "
"graph-based datasets."
@@ -164,8 +166,8 @@ class NormalizerDataCallback(Callback):
:param dataset: The `~pina.data.dataset.PinaDataset` dataset.
"""
for cond in conditions:
if cond in dataset.conditions_dict:
data = dataset.conditions_dict[cond][self.apply_to]
if cond in dataset:
data = dataset[cond].data[self.apply_to]
shift = self.shift_fn(data)
scale = self.scale_fn(data)
self._normalizer[cond] = {
@@ -197,25 +199,20 @@ class NormalizerDataCallback(Callback):
:param PinaDataset dataset: The dataset to be normalized.
"""
# Initialize update dictionary
update_dataset_dict = {}
# Iterate over conditions and apply normalization
for cond, norm_params in self.normalizer.items():
points = dataset.conditions_dict[cond][self.apply_to]
update_dataset_dict = {}
points = dataset[cond].data[self.apply_to]
scale = norm_params["scale"]
shift = norm_params["shift"]
normalized_points = self._norm_fn(points, scale, shift)
update_dataset_dict[cond] = {
self.apply_to: (
LabelTensor(normalized_points, points.labels)
if isinstance(points, LabelTensor)
else normalized_points
)
}
# Update the dataset in-place
dataset.update_data(update_dataset_dict)
update_dataset_dict[self.apply_to] = (
LabelTensor(normalized_points, points.labels)
if isinstance(points, LabelTensor)
else normalized_points
)
dataset[cond].data.update(update_dataset_dict)
@property
def normalizer(self):

View File

@@ -133,13 +133,12 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
:param PINNInterface solver: The solver object.
"""
new_points = {}
for name in self._condition_to_update:
current_points = self.dataset.conditions_dict[name]["input"]
new_points[name] = {
"input": self.sample(current_points, name, solver)
}
self.dataset.update_data(new_points)
new_points = {}
current_points = self.dataset[name].data["input"]
new_points["input"] = self.sample(current_points, name, solver)
self.dataset[name].update_data(new_points)
def _compute_population_size(self, conditions):
"""
@@ -150,6 +149,5 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
:rtype: dict
"""
return {
cond: len(self.dataset.conditions_dict[cond]["input"])
for cond in conditions
cond: len(self.dataset[cond].data["input"]) for cond in conditions
}