diff --git a/pina/callback/normalizer_data_callback.py b/pina/callback/normalizer_data_callback.py index ef957b9..faef8a9 100644 --- a/pina/callback/normalizer_data_callback.py +++ b/pina/callback/normalizer_data_callback.py @@ -5,7 +5,6 @@ from lightning.pytorch import Callback from ..label_tensor import LabelTensor from ..utils import check_consistency, is_function from ..condition import InputTargetCondition -from ..data.dataset import PinaGraphDataset class NormalizerDataCallback(Callback): @@ -122,7 +121,10 @@ class NormalizerDataCallback(Callback): """ # Ensure datsets are not graph-based - if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset): + if any( + ds.is_graph_dataset + for ds in trainer.datamodule.train_dataset.values() + ): raise NotImplementedError( "NormalizerDataCallback is not compatible with " "graph-based datasets." @@ -164,8 +166,8 @@ class NormalizerDataCallback(Callback): :param dataset: The `~pina.data.dataset.PinaDataset` dataset. """ for cond in conditions: - if cond in dataset.conditions_dict: - data = dataset.conditions_dict[cond][self.apply_to] + if cond in dataset: + data = dataset[cond].data[self.apply_to] shift = self.shift_fn(data) scale = self.scale_fn(data) self._normalizer[cond] = { @@ -197,25 +199,20 @@ class NormalizerDataCallback(Callback): :param PinaDataset dataset: The dataset to be normalized. """ - # Initialize update dictionary - update_dataset_dict = {} # Iterate over conditions and apply normalization for cond, norm_params in self.normalizer.items(): - points = dataset.conditions_dict[cond][self.apply_to] + update_dataset_dict = {} + points = dataset[cond].data[self.apply_to] scale = norm_params["scale"] shift = norm_params["shift"] normalized_points = self._norm_fn(points, scale, shift) - update_dataset_dict[cond] = { - self.apply_to: ( - LabelTensor(normalized_points, points.labels) - if isinstance(points, LabelTensor) - else normalized_points - ) - } - - # Update the dataset in-place - dataset.update_data(update_dataset_dict) + update_dataset_dict[self.apply_to] = ( + LabelTensor(normalized_points, points.labels) + if isinstance(points, LabelTensor) + else normalized_points + ) + dataset[cond].data.update(update_dataset_dict) @property def normalizer(self): diff --git a/pina/callback/refinement/refinement_interface.py b/pina/callback/refinement/refinement_interface.py index adc6e4e..c8bafa9 100644 --- a/pina/callback/refinement/refinement_interface.py +++ b/pina/callback/refinement/refinement_interface.py @@ -133,13 +133,12 @@ class RefinementInterface(Callback, metaclass=ABCMeta): :param PINNInterface solver: The solver object. """ - new_points = {} for name in self._condition_to_update: - current_points = self.dataset.conditions_dict[name]["input"] - new_points[name] = { - "input": self.sample(current_points, name, solver) - } - self.dataset.update_data(new_points) + new_points = {} + current_points = self.dataset[name].data["input"] + new_points["input"] = self.sample(current_points, name, solver) + + self.dataset[name].update_data(new_points) def _compute_population_size(self, conditions): """ @@ -150,6 +149,5 @@ class RefinementInterface(Callback, metaclass=ABCMeta): :rtype: dict """ return { - cond: len(self.dataset.conditions_dict[cond]["input"]) - for cond in conditions + cond: len(self.dataset[cond].data["input"]) for cond in conditions }