diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index a724256..160eb35 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -253,6 +253,7 @@ Callbacks
     Optimizer callback
     R3 Refinment callback
     Refinment Interface callback
+    Normalizer callback
 
 Losses and Weightings
 ---------------------
diff --git a/docs/source/_rst/callback/normalizer_data_callback.rst b/docs/source/_rst/callback/normalizer_data_callback.rst
new file mode 100644
index 0000000..6f59f7a
--- /dev/null
+++ b/docs/source/_rst/callback/normalizer_data_callback.rst
@@ -0,0 +1,7 @@
+Normalizer callback
+===================
+
+.. currentmodule:: pina.callback.normalizer_data_callback
+.. autoclass:: NormalizerDataCallback
+   :members:
+   :show-inheritance:
\ No newline at end of file
diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py
index e9a70ea..f71a89f 100644
--- a/pina/callback/__init__.py
+++ b/pina/callback/__init__.py
@@ -5,8 +5,10 @@ __all__ = [
     "MetricTracker",
     "PINAProgressBar",
     "R3Refinement",
+    "NormalizerDataCallback",
 ]
 
 from .optimizer_callback import SwitchOptimizer
 from .processing_callback import MetricTracker, PINAProgressBar
 from .refinement import R3Refinement
+from .normalizer_data_callback import NormalizerDataCallback
diff --git a/pina/callback/normalizer_data_callback.py b/pina/callback/normalizer_data_callback.py
new file mode 100644
index 0000000..ef957b9
--- /dev/null
+++ b/pina/callback/normalizer_data_callback.py
@@ -0,0 +1,228 @@
+"""Module for the Normalizer callback."""
+
+import torch
+from lightning.pytorch import Callback
+from ..label_tensor import LabelTensor
+from ..utils import check_consistency, is_function
+from ..condition import InputTargetCondition
+from ..data.dataset import PinaGraphDataset
+
+
+class NormalizerDataCallback(Callback):
+    r"""
+    A Callback used to normalize the dataset inputs or targets according to
+    user-provided scale and shift functions.
+
+    The transformation is applied as:
+
+    .. math::
+
+        x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}}
+
+    :Example:
+
+        >>> NormalizerDataCallback()
+        >>> NormalizerDataCallback(
+        ...     scale_fn=torch.std,
+        ...     shift_fn=torch.mean,
+        ...     stage="all",
+        ...     apply_to="input",
+        ... )
+    """
+
+    def __init__(
+        self,
+        scale_fn=torch.std,
+        shift_fn=torch.mean,
+        stage="all",
+        apply_to="input",
+    ):
+        """
+        Initialization of the :class:`NormalizerDataCallback` class.
+
+        :param Callable scale_fn: The function to compute the scaling factor.
+            Default is ``torch.std``.
+        :param Callable shift_fn: The function to compute the shifting factor.
+            Default is ``torch.mean``.
+        :param str stage: The stage in which normalization is applied.
+            Accepted values are "train", "validate", "test", or "all".
+            Default is ``"all"``.
+        :param str apply_to: Whether to normalize "input" or "target" data.
+            Default is ``"input"``.
+        :raises ValueError: If ``scale_fn`` is not callable.
+        :raises ValueError: If ``shift_fn`` is not callable.
+        """
+        super().__init__()
+
+        # Validate parameters
+        self.apply_to = self._validate_apply_to(apply_to)
+        self.stage = self._validate_stage(stage)
+
+        # Validate functions
+        if not is_function(scale_fn):
+            raise ValueError(f"scale_fn must be Callable, got {scale_fn}")
+        if not is_function(shift_fn):
+            raise ValueError(f"shift_fn must be Callable, got {shift_fn}")
+        self.scale_fn = scale_fn
+        self.shift_fn = shift_fn
+
+        # Initialize normalizer dictionary
+        self._normalizer = {}
+
+    def _validate_apply_to(self, apply_to):
+        """
+        Validate the ``apply_to`` parameter.
+
+        :param str apply_to: The candidate value for the ``apply_to`` parameter.
+        :raises ValueError: If ``apply_to`` is neither "input" nor "target".
+        :return: The validated ``apply_to`` value.
+        :rtype: str
+        """
+        check_consistency(apply_to, str)
+        if apply_to not in {"input", "target"}:
+            raise ValueError(
+                f"apply_to must be either 'input' or 'target', got {apply_to}"
+            )
+
+        return apply_to
+
+    def _validate_stage(self, stage):
+        """
+        Validate the ``stage`` parameter.
+
+        :param str stage: The candidate value for the ``stage`` parameter.
+        :raises ValueError: If ``stage`` is not one of "train", "validate",
+            "test", or "all".
+        :return: The validated ``stage`` value.
+        :rtype: str
+        """
+        check_consistency(stage, str)
+        if stage not in {"train", "validate", "test", "all"}:
+            raise ValueError(
+                "stage must be one of 'train', 'validate', 'test', or 'all',"
+                f" got {stage}"
+            )
+
+        return stage
+
+    def setup(self, trainer, pl_module, stage):
+        """
+        Apply normalization during setup.
+
+        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
+        :param SolverInterface pl_module: A
+            :class:`~pina.solver.solver.SolverInterface` instance.
+        :param str stage: The current stage.
+        :raises NotImplementedError: If the dataset is graph-based.
+        :raises RuntimeError: If the training dataset is not available when
+            computing normalization parameters.
+        :return: The result of the parent setup.
+        :rtype: Any
+        """
+
+        # Ensure datasets are not graph-based
+        if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset):
+            raise NotImplementedError(
+                "NormalizerDataCallback is not compatible with "
+                "graph-based datasets."
+            )
+
+        # Extract the conditions to normalize
+        conditions_to_normalize = [
+            name
+            for name, cond in pl_module.problem.conditions.items()
+            if isinstance(cond, InputTargetCondition)
+        ]
+
+        # Compute scale and shift parameters
+        if not self.normalizer:
+            if not trainer.datamodule.train_dataset:
+                raise RuntimeError(
+                    "Training dataset is not available. Cannot compute "
+                    "normalization parameters."
+                )
+            self._compute_scale_shift(
+                conditions_to_normalize, trainer.datamodule.train_dataset
+            )
+
+        # Apply normalization based on the specified stage
+        if stage == "fit" and self.stage in ["train", "all"]:
+            self.normalize_dataset(trainer.datamodule.train_dataset)
+        if stage == "fit" and self.stage in ["validate", "all"]:
+            self.normalize_dataset(trainer.datamodule.val_dataset)
+        if stage == "test" and self.stage in ["test", "all"]:
+            self.normalize_dataset(trainer.datamodule.test_dataset)
+
+        return super().setup(trainer, pl_module, stage)
+
+    def _compute_scale_shift(self, conditions, dataset):
+        """
+        Compute scale and shift parameters for each condition in the dataset.
+
+        :param list conditions: The list of condition names.
+        :param PinaDataset dataset: The dataset used to compute the
+            normalization parameters.
+        """
+        for cond in conditions:
+            if cond in dataset.conditions_dict:
+                data = dataset.conditions_dict[cond][self.apply_to]
+                shift = self.shift_fn(data)
+                scale = self.scale_fn(data)
+                self._normalizer[cond] = {
+                    "shift": shift,
+                    "scale": scale,
+                }
+
+    @staticmethod
+    def _norm_fn(value, scale, shift):
+        """
+        Normalize a value according to the scale and shift parameters.
+
+        :param value: The input tensor to normalize.
+        :type value: torch.Tensor | LabelTensor
+        :param torch.Tensor scale: The scaling factor.
+        :param torch.Tensor shift: The shifting factor.
+        :return: The normalized tensor.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        scaled_value = (value - shift) / scale
+        if isinstance(value, LabelTensor):
+            scaled_value = LabelTensor(scaled_value, value.labels)
+
+        return scaled_value
+
+    def normalize_dataset(self, dataset):
+        """
+        Apply in-place normalization to the dataset.
+
+        :param PinaDataset dataset: The dataset to be normalized.
+        """
+        # Initialize update dictionary
+        update_dataset_dict = {}
+
+        # Iterate over conditions and apply normalization
+        for cond, norm_params in self.normalizer.items():
+            points = dataset.conditions_dict[cond][self.apply_to]
+            scale = norm_params["scale"]
+            shift = norm_params["shift"]
+            normalized_points = self._norm_fn(points, scale, shift)
+            update_dataset_dict[cond] = {
+                self.apply_to: (
+                    LabelTensor(normalized_points, points.labels)
+                    if isinstance(points, LabelTensor)
+                    else normalized_points
+                )
+            }
+
+        # Update the dataset in-place
+        dataset.update_data(update_dataset_dict)
+
+    @property
+    def normalizer(self):
+        """
+        Get the dictionary of normalization parameters.
+
+        :return: The dictionary of normalization parameters.
+        :rtype: dict
+        """
+        return self._normalizer
diff --git a/pina/utils.py b/pina/utils.py
index 2aafba1..efc4842 100644
--- a/pina/utils.py
+++ b/pina/utils.py
@@ -206,7 +206,7 @@ def is_function(f):
-    :return: ``True`` if ``f`` is a function, ``False`` otherwise.
+    :return: ``True`` if ``f`` is callable, ``False`` otherwise.
     :rtype: bool
     """
-    return isinstance(f, (types.FunctionType, types.LambdaType))
+    return callable(f)
 
 
 def chebyshev_roots(n):
diff --git a/tests/test_callback/test_normalizer_data_callback.py b/tests/test_callback/test_normalizer_data_callback.py
new file mode 100644
index 0000000..7cdcc95
--- /dev/null
+++ b/tests/test_callback/test_normalizer_data_callback.py
@@ -0,0 +1,244 @@
+import torch
+import pytest
+from copy import deepcopy
+
+from pina import Trainer, LabelTensor, Condition
+from pina.solver import SupervisedSolver
+from pina.model import FeedForward
+from pina.callback import NormalizerDataCallback
+from pina.problem import AbstractProblem
+from pina.problem.zoo import Poisson2DSquareProblem as Poisson
+from pina.solver import PINN
+from pina.graph import RadiusGraph
+
+# Map each stage to the datasets expected to be normalized
+stage_map = {
+    "train": ["train_dataset"],
+    "validate": ["val_dataset"],
+    "test": ["test_dataset"],
+    "all": ["train_dataset", "val_dataset", "test_dataset"],
+}
+
+input_1 = torch.rand(20, 2) * 10
+target_1 = torch.rand(20, 1) * 10
+input_2 = torch.rand(20, 2) * 5
+target_2 = torch.rand(20, 1) * 5
+
+
+class LabelTensorProblem(AbstractProblem):
+    input_variables = ["u_0", "u_1"]
+    output_variables = ["u"]
+    conditions = {
+        "data1": Condition(
+            input=LabelTensor(input_1, ["u_0", "u_1"]),
+            target=LabelTensor(target_1, ["u"]),
+        ),
+        "data2": Condition(
+            input=LabelTensor(input_2, ["u_0", "u_1"]),
+            target=LabelTensor(target_2, ["u"]),
+        ),
+    }
+
+
+class TensorProblem(AbstractProblem):
+    input_variables = ["u_0", "u_1"]
+    output_variables = ["u"]
+    conditions = {
+        "data1": Condition(input=input_1, target=target_1),
+        "data2": Condition(input=input_2, target=target_2),
+    }
+
+
+input_graph = [RadiusGraph(radius=0.5, pos=torch.rand(10, 2)) for _ in range(5)]
+output_graph = torch.rand(5, 1)
+
+
+class GraphProblem(AbstractProblem):
+    input_variables = ["u_0", "u_1"]
+    output_variables = ["u"]
+    conditions = {
+        "data": Condition(input=input_graph, target=output_graph),
+    }
+
+
+supervised_solver_no_lt = SupervisedSolver(
+    problem=TensorProblem(), model=FeedForward(2, 1), use_lt=False
+)
+supervised_solver_lt = SupervisedSolver(
+    problem=LabelTensorProblem(), model=FeedForward(2, 1), use_lt=True
+)
+
+poisson_problem = Poisson()
+poisson_problem.conditions["data"] = Condition(
+    input=LabelTensor(torch.rand(20, 2) * 10, ["x", "y"]),
+    target=LabelTensor(torch.rand(20, 1) * 10, ["u"]),
+)
+
+
+@pytest.mark.parametrize("scale_fn", [torch.std, torch.var])
+@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median])
+@pytest.mark.parametrize("apply_to", ["input", "target"])
+@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"])
+def test_init(scale_fn, shift_fn, apply_to, stage):
+    normalizer = NormalizerDataCallback(
+        scale_fn=scale_fn, shift_fn=shift_fn, apply_to=apply_to, stage=stage
+    )
+    assert normalizer.scale_fn == scale_fn
+    assert normalizer.shift_fn == shift_fn
+    assert normalizer.apply_to == apply_to
+    assert normalizer.stage == stage
+
+
+def test_init_invalid_scale():
+    with pytest.raises(ValueError):
+        NormalizerDataCallback(scale_fn=1)
+
+
+def test_init_invalid_shift():
+    with pytest.raises(ValueError):
+        NormalizerDataCallback(shift_fn=1)
+
+
+@pytest.mark.parametrize("invalid_apply_to", ["inputt", "targett", 1])
+def test_init_invalid_apply_to(invalid_apply_to):
+    with pytest.raises(ValueError):
+        NormalizerDataCallback(apply_to=invalid_apply_to)
+
+
+@pytest.mark.parametrize("invalid_stage", ["trainn", "validatee", 1])
+def test_init_invalid_stage(invalid_stage):
+    with pytest.raises(ValueError):
+        NormalizerDataCallback(stage=invalid_stage)
+
+
+@pytest.mark.parametrize(
+    "solver", [supervised_solver_lt, supervised_solver_no_lt]
+)
+@pytest.mark.parametrize(
+    "fn", [[torch.std, torch.mean], [torch.var, torch.median]]
+)
+@pytest.mark.parametrize("apply_to", ["input", "target"])
+@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"])
+def test_setup(solver, fn, stage, apply_to):
+    scale_fn, shift_fn = fn
+    trainer = Trainer(
+        solver=solver,
+        callbacks=NormalizerDataCallback(
+            scale_fn=scale_fn, shift_fn=shift_fn, stage=stage, apply_to=apply_to
+        ),
+        max_epochs=1,
+        train_size=0.4,
+        val_size=0.3,
+        test_size=0.3,
+        shuffle=False,
+    )
+    trainer_copy = deepcopy(trainer)
+    trainer_copy.data_module.setup("fit")
+    trainer_copy.data_module.setup("test")
+    trainer.train()
+    trainer.test()
+
+    normalizer = trainer.callbacks[0].normalizer
+
+    for cond in ["data1", "data2"]:
+        scale = scale_fn(
+            trainer_copy.data_module.train_dataset.conditions_dict[cond][
+                apply_to
+            ]
+        )
+        shift = shift_fn(
+            trainer_copy.data_module.train_dataset.conditions_dict[cond][
+                apply_to
+            ]
+        )
+        assert "scale" in normalizer[cond]
+        assert "shift" in normalizer[cond]
+        assert abs(normalizer[cond]["scale"] - scale) < 1e-5
+        assert abs(normalizer[cond]["shift"] - shift) < 1e-5
+        for ds_name in stage_map[stage]:
+            dataset = getattr(trainer.data_module, ds_name, None)
+            old_dataset = getattr(trainer_copy.data_module, ds_name, None)
+            current_points = dataset.conditions_dict[cond][apply_to]
+            old_points = old_dataset.conditions_dict[cond][apply_to]
+            expected = (old_points - shift) / scale
+            assert torch.allclose(current_points, expected)
+
+
+@pytest.mark.parametrize(
+    "fn", [[torch.std, torch.mean], [torch.var, torch.median]]
+)
+@pytest.mark.parametrize("apply_to", ["input"])
+@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"])
+def test_setup_pinn(fn, stage, apply_to):
+    scale_fn, shift_fn = fn
+    pinn = PINN(
+        problem=poisson_problem,
+        model=FeedForward(2, 1),
+    )
+    poisson_problem.discretise_domain(n=10)
+    trainer = Trainer(
+        solver=pinn,
+        callbacks=NormalizerDataCallback(
+            scale_fn=scale_fn,
+            shift_fn=shift_fn,
+            stage=stage,
+            apply_to=apply_to,
+        ),
+        max_epochs=1,
+        train_size=0.4,
+        val_size=0.3,
+        test_size=0.3,
+        shuffle=False,
+    )
+
+    trainer_copy = deepcopy(trainer)
+    trainer_copy.data_module.setup("fit")
+    trainer_copy.data_module.setup("test")
+    trainer.train()
+    trainer.test()
+
+    conditions = trainer.callbacks[0].normalizer.keys()
+    assert "data" in conditions
+    assert len(conditions) == 1
+    normalizer = trainer.callbacks[0].normalizer
+    cond = "data"
+
+    scale = scale_fn(
+        trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
+    )
+    shift = shift_fn(
+        trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
+    )
+    assert "scale" in normalizer[cond]
+    assert "shift" in normalizer[cond]
+    assert abs(normalizer[cond]["scale"] - scale) < 1e-5
+    assert abs(normalizer[cond]["shift"] - shift) < 1e-5
+    for ds_name in stage_map[stage]:
+        dataset = getattr(trainer.data_module, ds_name, None)
+        old_dataset = getattr(trainer_copy.data_module, ds_name, None)
+        current_points = dataset.conditions_dict[cond][apply_to]
+        old_points = old_dataset.conditions_dict[cond][apply_to]
+        expected = (old_points - shift) / scale
+        assert torch.allclose(current_points, expected)
+
+
+def test_setup_graph_dataset():
+    solver = SupervisedSolver(
+        problem=GraphProblem(), model=FeedForward(2, 1), use_lt=False
+    )
+    trainer = Trainer(
+        solver=solver,
+        callbacks=NormalizerDataCallback(
+            scale_fn=torch.std,
+            shift_fn=torch.mean,
+            stage="all",
+            apply_to="input",
+        ),
+        max_epochs=1,
+        train_size=0.4,
+        val_size=0.3,
+        test_size=0.3,
+        shuffle=False,
+    )
+    with pytest.raises(NotImplementedError):
+        trainer.train()
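
For reviewers, a minimal usage sketch of the new callback, assembled from the
tests above (the SimpleProblem name and its toy data are illustrative, not
part of the patch):

    import torch
    from pina import Trainer, LabelTensor, Condition
    from pina.callback import NormalizerDataCallback
    from pina.model import FeedForward
    from pina.problem import AbstractProblem
    from pina.solver import SupervisedSolver

    # A toy supervised problem with a single input/target condition.
    class SimpleProblem(AbstractProblem):
        input_variables = ["u_0", "u_1"]
        output_variables = ["u"]
        conditions = {
            "data": Condition(
                input=LabelTensor(torch.rand(20, 2) * 10, ["u_0", "u_1"]),
                target=LabelTensor(torch.rand(20, 1) * 10, ["u"]),
            ),
        }

    solver = SupervisedSolver(problem=SimpleProblem(), model=FeedForward(2, 1))
    trainer = Trainer(
        solver=solver,
        # Standardize the inputs of each condition with the mean and
        # standard deviation computed per condition on the training split.
        callbacks=NormalizerDataCallback(
            scale_fn=torch.std,
            shift_fn=torch.mean,
            stage="all",
            apply_to="input",
        ),
        max_epochs=1,
        train_size=0.4,
        val_size=0.3,
        test_size=0.3,
    )
    trainer.train()

Note the design choice visible in setup(): the scale/shift statistics are
computed once from the training dataset and the same transform is then applied
to the validation and test splits, so no statistics leak from held-out data.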