Add Normalizer Callback (#631)

* add normalizer callback

* implement shift and scale parameters computation

* rename normalizer data callback files

* reduce tests

* fix documentation

* add NotImplementedError for PinaGraphDataset

---------

Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Authored by Dario Coscia on 2025-09-16 17:29:05 +02:00, committed by GitHub
parent ef75f13bcb
commit dc808c1d77
6 changed files with 483 additions and 1 deletion


@@ -253,6 +253,7 @@ Callbacks
    Optimizer callback <callback/optimizer_callback.rst>
    R3 Refinement callback <callback/refinement/r3_refinement.rst>
    Refinement Interface callback <callback/refinement/refinement_interface.rst>
    Normalizer callback <callback/normalizer_data_callback.rst>
Losses and Weightings
---------------------


@@ -0,0 +1,7 @@
Normalizer callbacks
====================

.. currentmodule:: pina.callback.normalizer_data_callback

.. autoclass:: NormalizerDataCallback
    :members:
    :show-inheritance:


@@ -5,8 +5,10 @@ __all__ = [
"MetricTracker",
"PINAProgressBar",
"R3Refinement",
"NormalizerDataCallback",
]
from .optimizer_callback import SwitchOptimizer
from .processing_callback import MetricTracker, PINAProgressBar
from .refinement import R3Refinement
from .normalizer_data_callback import NormalizerDataCallback
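
With the new export in place, the callback is importable directly from the package namespace; this is the import path the tests below use:

from pina.callback import NormalizerDataCallback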


@@ -0,0 +1,228 @@
"""Module for the Normalizer callback."""
import torch
from lightning.pytorch import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency, is_function
from ..condition import InputTargetCondition
from ..data.dataset import PinaGraphDataset
class NormalizerDataCallback(Callback):
r"""
A Callback used to normalize the dataset inputs or targets according to
user-provided scale and shift functions.
The transformation is applied as:
.. math::
x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}}
:Example:
>>> NormalizerDataCallback()
>>> NormalizerDataCallback(
... scale_fn: torch.std,
... shift_fn: torch.mean,
... stage: "all",
... apply_to: "input",
... )
"""
def __init__(
self,
scale_fn=torch.std,
shift_fn=torch.mean,
stage="all",
apply_to="input",
):
"""
Initialization of the :class:`NormalizerDataCallback` class.
:param Callable scale_fn: The function to compute the scaling factor.
Default is ``torch.std``.
:param Callable shift_fn: The function to compute the shifting factor.
Default is ``torch.mean``.
:param str stage: The stage in which normalization is applied.
Accepted values are "train", "validate", "test", or "all".
Default is ``"all"``.
:param str apply_to: Whether to normalize "input" or "target" data.
Default is ``"input"``.
:raises ValueError: If ``scale_fn`` is not callable.
:raises ValueError: If ``shift_fn`` is not callable.
"""
super().__init__()
# Validate parameters
self.apply_to = self._validate_apply_to(apply_to)
self.stage = self._validate_stage(stage)
# Validate functions
if not is_function(scale_fn):
raise ValueError(f"scale_fn must be Callable, got {scale_fn}")
if not is_function(shift_fn):
raise ValueError(f"shift_fn must be Callable, got {shift_fn}")
self.scale_fn = scale_fn
self.shift_fn = shift_fn
# Initialize normalizer dictionary
self._normalizer = {}
def _validate_apply_to(self, apply_to):
"""
Validate the ``apply_to`` parameter.
:param str apply_to: The candidate value for the ``apply_to`` parameter.
:raises ValueError: If ``apply_to`` is neither "input" nor "target".
:return: The validated ``apply_to`` value.
:rtype: str
"""
check_consistency(apply_to, str)
if apply_to not in {"input", "target"}:
raise ValueError(
f"apply_to must be either 'input' or 'target', got {apply_to}"
)
return apply_to
def _validate_stage(self, stage):
"""
Validate the ``stage`` parameter.
:param str stage: The candidate value for the ``stage`` parameter.
:raises ValueError: If ``stage`` is not one of "train", "validate",
"test", or "all".
:return: The validated ``stage`` value.
:rtype: str
"""
check_consistency(stage, str)
if stage not in {"train", "validate", "test", "all"}:
raise ValueError(
"stage must be one of 'train', 'validate', 'test', or 'all',"
f" got {stage}"
)
return stage
def setup(self, trainer, pl_module, stage):
"""
Apply normalization during setup.
:param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
:param SolverInterface pl_module: A
:class:`~pina.solver.solver.SolverInterface` instance.
:param str stage: The current stage.
:raises RuntimeError: If the training dataset is not available when
computing normalization parameters.
:return: The result of the parent setup.
:rtype: Any
:raises NotImplementedError: If the dataset is graph-based.
"""
# Ensure datsets are not graph-based
if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset):
raise NotImplementedError(
"NormalizerDataCallback is not compatible with "
"graph-based datasets."
)
# Extract conditions
conditions_to_normalize = [
name
for name, cond in pl_module.problem.conditions.items()
if isinstance(cond, InputTargetCondition)
]
# Compute scale and shift parameters
if not self.normalizer:
if not trainer.datamodule.train_dataset:
raise RuntimeError(
"Training dataset is not available. Cannot compute "
"normalization parameters."
)
self._compute_scale_shift(
conditions_to_normalize, trainer.datamodule.train_dataset
)
# Apply normalization based on the specified stage
if stage == "fit" and self.stage in ["train", "all"]:
self.normalize_dataset(trainer.datamodule.train_dataset)
if stage == "fit" and self.stage in ["validate", "all"]:
self.normalize_dataset(trainer.datamodule.val_dataset)
if stage == "test" and self.stage in ["test", "all"]:
self.normalize_dataset(trainer.datamodule.test_dataset)
return super().setup(trainer, pl_module, stage)
def _compute_scale_shift(self, conditions, dataset):
"""
Compute scale and shift parameters for each condition in the dataset.
:param list conditions: The list of condition names.
:param dataset: The `~pina.data.dataset.PinaDataset` dataset.
"""
for cond in conditions:
if cond in dataset.conditions_dict:
data = dataset.conditions_dict[cond][self.apply_to]
shift = self.shift_fn(data)
scale = self.scale_fn(data)
self._normalizer[cond] = {
"shift": shift,
"scale": scale,
}
@staticmethod
def _norm_fn(value, scale, shift):
"""
Normalize a value according to the scale and shift parameters.
:param value: The input tensor to normalize.
:type value: torch.Tensor | LabelTensor
:param float scale: The scaling factor.
:param float shift: The shifting factor.
:return: The normalized tensor.
:rtype: torch.Tensor | LabelTensor
"""
scaled_value = (value - shift) / scale
if isinstance(value, LabelTensor):
scaled_value = LabelTensor(scaled_value, value.labels)
return scaled_value
def normalize_dataset(self, dataset):
"""
Apply in-place normalization to the dataset.
:param PinaDataset dataset: The dataset to be normalized.
"""
# Initialize update dictionary
update_dataset_dict = {}
# Iterate over conditions and apply normalization
for cond, norm_params in self.normalizer.items():
points = dataset.conditions_dict[cond][self.apply_to]
scale = norm_params["scale"]
shift = norm_params["shift"]
normalized_points = self._norm_fn(points, scale, shift)
update_dataset_dict[cond] = {
self.apply_to: (
LabelTensor(normalized_points, points.labels)
if isinstance(points, LabelTensor)
else normalized_points
)
}
# Update the dataset in-place
dataset.update_data(update_dataset_dict)
@property
def normalizer(self):
"""
Get the dictionary of normalization parameters.
:return: The dictionary of normalization parameters.
:rtype: dict
"""
return self._normalizer
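
For quick reference, a minimal sketch (plain torch, not part of this commit) of the transform that ``_norm_fn`` applies when the default ``scale_fn``/``shift_fn`` are used:

import torch

# Toy stand-in for a condition's data; shapes follow the tests below.
data = torch.rand(20, 2) * 10

# Defaults of NormalizerDataCallback: shift_fn=torch.mean, scale_fn=torch.std,
# both computed over the whole tensor.
shift = torch.mean(data)
scale = torch.std(data)

# The transform: x_new = (x - shift) / scale.
normalized = (data - shift) / scale

# With these defaults the result is standardized to zero mean and unit std.
print(normalized.mean().item(), normalized.std().item())  # ~0.0, ~1.0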


@@ -206,7 +206,7 @@ def is_function(f):
:return: ``True`` if ``f`` is a function, ``False`` otherwise.
:rtype: bool
"""
-    return isinstance(f, (types.FunctionType, types.LambdaType))
+    return callable(f)

def chebyshev_roots(n):
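
The relaxed check matters because the callback's defaults, ``torch.std`` and ``torch.mean``, are C-level builtins rather than Python function objects, so the old ``isinstance`` test would likely have rejected them. A small standard-library sketch of the difference:

import types
from functools import partial

def f(x):
    return x

# Old check: only plain Python functions and lambdas pass.
print(isinstance(f, (types.FunctionType, types.LambdaType)))              # True
print(isinstance(partial(f, 1), (types.FunctionType, types.LambdaType)))  # False

# New check: partials, builtins, bound methods, and objects defining
# __call__ all pass.
print(callable(f), callable(partial(f, 1)))  # True True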


@@ -0,0 +1,244 @@
import torch
import pytest
from copy import deepcopy

from pina import Trainer, LabelTensor, Condition
from pina.solver import SupervisedSolver
from pina.model import FeedForward
from pina.callback import NormalizerDataCallback
from pina.problem import AbstractProblem
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.solver import PINN
from pina.graph import RadiusGraph

# for checking normalization
stage_map = {
    "train": ["train_dataset"],
    "validate": ["val_dataset"],
    "test": ["test_dataset"],
    "all": ["train_dataset", "val_dataset", "test_dataset"],
}

input_1 = torch.rand(20, 2) * 10
target_1 = torch.rand(20, 1) * 10
input_2 = torch.rand(20, 2) * 5
target_2 = torch.rand(20, 1) * 5


class LabelTensorProblem(AbstractProblem):
    input_variables = ["u_0", "u_1"]
    output_variables = ["u"]
    conditions = {
        "data1": Condition(
            input=LabelTensor(input_1, ["u_0", "u_1"]),
            target=LabelTensor(target_1, ["u"]),
        ),
        "data2": Condition(
            input=LabelTensor(input_2, ["u_0", "u_1"]),
            target=LabelTensor(target_2, ["u"]),
        ),
    }


class TensorProblem(AbstractProblem):
    input_variables = ["u_0", "u_1"]
    output_variables = ["u"]
    conditions = {
        "data1": Condition(input=input_1, target=target_1),
        "data2": Condition(input=input_2, target=target_2),
    }


input_graph = [RadiusGraph(radius=0.5, pos=torch.rand(10, 2)) for _ in range(5)]
output_graph = torch.rand(5, 1)


class GraphProblem(AbstractProblem):
    input_variables = ["u_0", "u_1"]
    output_variables = ["u"]
    conditions = {
        "data": Condition(input=input_graph, target=output_graph),
    }


supervised_solver_no_lt = SupervisedSolver(
    problem=TensorProblem(), model=FeedForward(2, 1), use_lt=False
)
supervised_solver_lt = SupervisedSolver(
    problem=LabelTensorProblem(), model=FeedForward(2, 1), use_lt=True
)

poisson_problem = Poisson()
poisson_problem.conditions["data"] = Condition(
    input=LabelTensor(torch.rand(20, 2) * 10, ["x", "y"]),
    target=LabelTensor(torch.rand(20, 1) * 10, ["u"]),
)


@pytest.mark.parametrize("scale_fn", [torch.std, torch.var])
@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median])
@pytest.mark.parametrize("apply_to", ["input", "target"])
@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"])
def test_init(scale_fn, shift_fn, apply_to, stage):
    normalizer = NormalizerDataCallback(
        scale_fn=scale_fn, shift_fn=shift_fn, apply_to=apply_to, stage=stage
    )
    assert normalizer.scale_fn == scale_fn
    assert normalizer.shift_fn == shift_fn
    assert normalizer.apply_to == apply_to
    assert normalizer.stage == stage


def test_init_invalid_scale():
    with pytest.raises(ValueError):
        NormalizerDataCallback(scale_fn=1)


def test_init_invalid_shift():
    with pytest.raises(ValueError):
        NormalizerDataCallback(shift_fn=1)


@pytest.mark.parametrize("invalid_apply_to", ["inputt", "targett", 1])
def test_init_invalid_apply_to(invalid_apply_to):
    with pytest.raises(ValueError):
        NormalizerDataCallback(apply_to=invalid_apply_to)


@pytest.mark.parametrize("invalid_stage", ["trainn", "validatee", 1])
def test_init_invalid_stage(invalid_stage):
    with pytest.raises(ValueError):
        NormalizerDataCallback(stage=invalid_stage)


@pytest.mark.parametrize(
    "solver", [supervised_solver_lt, supervised_solver_no_lt]
)
@pytest.mark.parametrize(
    "fn", [[torch.std, torch.mean], [torch.var, torch.median]]
)
@pytest.mark.parametrize("apply_to", ["input", "target"])
@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"])
def test_setup(solver, fn, stage, apply_to):
    scale_fn, shift_fn = fn
    trainer = Trainer(
        solver=solver,
        callbacks=NormalizerDataCallback(
            scale_fn=scale_fn, shift_fn=shift_fn, stage=stage, apply_to=apply_to
        ),
        max_epochs=1,
        train_size=0.4,
        val_size=0.3,
        test_size=0.3,
        shuffle=False,
    )
    trainer_copy = deepcopy(trainer)
    trainer_copy.data_module.setup("fit")
    trainer_copy.data_module.setup("test")
    trainer.train()
    trainer.test()
    normalizer = trainer.callbacks[0].normalizer
    for cond in ["data1", "data2"]:
        scale = scale_fn(
            trainer_copy.data_module.train_dataset.conditions_dict[cond][
                apply_to
            ]
        )
        shift = shift_fn(
            trainer_copy.data_module.train_dataset.conditions_dict[cond][
                apply_to
            ]
        )
        assert "scale" in normalizer[cond]
        assert "shift" in normalizer[cond]
        assert abs(normalizer[cond]["scale"] - scale) < 1e-5
        assert abs(normalizer[cond]["shift"] - shift) < 1e-5
        for ds_name in stage_map[stage]:
            dataset = getattr(trainer.data_module, ds_name, None)
            old_dataset = getattr(trainer_copy.data_module, ds_name, None)
            current_points = dataset.conditions_dict[cond][apply_to]
            old_points = old_dataset.conditions_dict[cond][apply_to]
            expected = (old_points - shift) / scale
            assert torch.allclose(current_points, expected)


@pytest.mark.parametrize(
    "fn", [[torch.std, torch.mean], [torch.var, torch.median]]
)
@pytest.mark.parametrize("apply_to", ["input"])
@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"])
def test_setup_pinn(fn, stage, apply_to):
    scale_fn, shift_fn = fn
    pinn = PINN(
        problem=poisson_problem,
        model=FeedForward(2, 1),
    )
    poisson_problem.discretise_domain(n=10)
    trainer = Trainer(
        solver=pinn,
        callbacks=NormalizerDataCallback(
            scale_fn=scale_fn,
            shift_fn=shift_fn,
            stage=stage,
            apply_to=apply_to,
        ),
        max_epochs=1,
        train_size=0.4,
        val_size=0.3,
        test_size=0.3,
        shuffle=False,
    )
    trainer_copy = deepcopy(trainer)
    trainer_copy.data_module.setup("fit")
    trainer_copy.data_module.setup("test")
    trainer.train()
    trainer.test()
    conditions = trainer.callbacks[0].normalizer.keys()
    assert "data" in conditions
    assert len(conditions) == 1
    normalizer = trainer.callbacks[0].normalizer
    cond = "data"
    scale = scale_fn(
        trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
    )
    shift = shift_fn(
        trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
    )
    assert "scale" in normalizer[cond]
    assert "shift" in normalizer[cond]
    assert abs(normalizer[cond]["scale"] - scale) < 1e-5
    assert abs(normalizer[cond]["shift"] - shift) < 1e-5
    for ds_name in stage_map[stage]:
        dataset = getattr(trainer.data_module, ds_name, None)
        old_dataset = getattr(trainer_copy.data_module, ds_name, None)
        current_points = dataset.conditions_dict[cond][apply_to]
        old_points = old_dataset.conditions_dict[cond][apply_to]
        expected = (old_points - shift) / scale
        assert torch.allclose(current_points, expected)


def test_setup_graph_dataset():
    solver = SupervisedSolver(
        problem=GraphProblem(), model=FeedForward(2, 1), use_lt=False
    )
    trainer = Trainer(
        solver=solver,
        callbacks=NormalizerDataCallback(
            scale_fn=torch.std,
            shift_fn=torch.mean,
            stage="all",
            apply_to="input",
        ),
        max_epochs=1,
        train_size=0.4,
        val_size=0.3,
        test_size=0.3,
        shuffle=False,
    )
    with pytest.raises(NotImplementedError):
        trainer.train()