Files
PINA/tests/test_callback/test_normalizer_data_callback.py
Dario Coscia dc808c1d77 Add Normalizer Callback (#631)
* add normalizer callback

* implement shift and scale parameters computation

* change name files normalizer data callback

* reduce tests

* fix documentation

* add NotImplementedError for PinaGraphDataset

---------

Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
2025-09-16 17:29:05 +02:00

245 lines
7.5 KiB
Python

import torch
import pytest
from copy import deepcopy
from pina import Trainer, LabelTensor, Condition
from pina.solver import SupervisedSolver
from pina.model import FeedForward
from pina.callback import NormalizerDataCallback
from pina.problem import AbstractProblem
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.solver import PINN
from pina.graph import RadiusGraph
# Lookup from a callback ``stage`` option to the names of the data-module
# dataset attributes that the setup tests inspect for that option.
stage_map = {
    "train": ["train_dataset"],
    "validate": ["val_dataset"],
    "test": ["test_dataset"],
}
# "all" covers every split, in train / validate / test order.
stage_map["all"] = [
    *stage_map["train"],
    *stage_map["validate"],
    *stage_map["test"],
]

# Random supervised data shared by the problem classes below: two sets of
# 20 samples with 2 input columns and 1 target column, the first scaled by
# 10 and the second by 5.
input_1 = 10 * torch.rand(20, 2)
target_1 = 10 * torch.rand(20, 1)
input_2 = 5 * torch.rand(20, 2)
target_2 = 5 * torch.rand(20, 1)
class LabelTensorProblem(AbstractProblem):
    """
    Problem with two supervised conditions, ``data1`` and ``data2``, whose
    inputs and targets are wrapped in :class:`LabelTensor` with labels
    ``["u_0", "u_1"]`` and ``["u"]`` respectively.
    """

    input_variables = ["u_0", "u_1"]
    output_variables = ["u"]
    # The label lists are spelled out inside the comprehension on purpose:
    # class-level names are not visible from a comprehension body.
    conditions = {
        name: Condition(
            input=LabelTensor(points, ["u_0", "u_1"]),
            target=LabelTensor(values, ["u"]),
        )
        for name, points, values in (
            ("data1", input_1, target_1),
            ("data2", input_2, target_2),
        )
    }
class TensorProblem(AbstractProblem):
    """
    Problem with the same two supervised conditions as the labelled variant,
    but built directly from the module-level tensors ``input_1``/``target_1``
    and ``input_2``/``target_2`` without any label wrapping.
    """

    input_variables = ["u_0", "u_1"]
    output_variables = ["u"]
    conditions = dict(
        data1=Condition(input=input_1, target=target_1),
        data2=Condition(input=input_2, target=target_2),
    )
# Graph data for GraphProblem below: five RadiusGraph objects (radius 0.5),
# each built from its own random ``(10, 2)`` position tensor, plus a random
# ``(5, 1)`` target tensor (one row per graph).
input_graph = [RadiusGraph(radius=0.5, pos=torch.rand(10, 2)) for _ in range(5)]
output_graph = torch.rand(5, 1)
class GraphProblem(AbstractProblem):
    """
    Problem with a single supervised condition, ``data``, whose input is the
    module-level list of graphs ``input_graph`` and whose target is
    ``output_graph``.
    """

    input_variables = ["u_0", "u_1"]
    output_variables = ["u"]
    conditions = dict(
        data=Condition(input=input_graph, target=output_graph),
    )
# Module-level solvers used as the ``solver`` parameter values of
# ``test_setup``: one on the plain-tensor problem (``use_lt=False``) and one
# on the LabelTensor problem (``use_lt=True``). Each is created once at
# import time, so the same objects are reused for every parametrization.
supervised_solver_no_lt = SupervisedSolver(
problem=TensorProblem(), model=FeedForward(2, 1), use_lt=False
)
supervised_solver_lt = SupervisedSolver(
problem=LabelTensorProblem(), model=FeedForward(2, 1), use_lt=True
)
# Poisson problem from the zoo with an extra supervised condition named
# "data" (random ``x``/``y`` inputs and ``u`` targets) added to its
# conditions; used by ``test_setup_pinn``.
poisson_problem = Poisson()
poisson_problem.conditions["data"] = Condition(
input=LabelTensor(torch.rand(20, 2) * 10, ["x", "y"]),
target=LabelTensor(torch.rand(20, 1) * 10, ["u"]),
)
@pytest.mark.parametrize("scale_fn", [torch.std, torch.var])
@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median])
@pytest.mark.parametrize("apply_to", ["input", "target"])
@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"])
def test_init(scale_fn, shift_fn, apply_to, stage):
    """
    Check that the constructor exposes every argument it was given on the
    attribute of the same name.
    """
    init_kwargs = {
        "scale_fn": scale_fn,
        "shift_fn": shift_fn,
        "apply_to": apply_to,
        "stage": stage,
    }
    normalizer = NormalizerDataCallback(**init_kwargs)
    for attribute, expected in init_kwargs.items():
        assert getattr(normalizer, attribute) == expected
def test_init_invalid_scale():
    """Check that the integer ``1`` passed as ``scale_fn`` raises ValueError."""
    not_a_function = 1
    with pytest.raises(ValueError):
        NormalizerDataCallback(scale_fn=not_a_function)
def test_init_invalid_shift():
    """Check that the integer ``1`` passed as ``shift_fn`` raises ValueError."""
    not_a_function = 1
    with pytest.raises(ValueError):
        NormalizerDataCallback(shift_fn=not_a_function)
@pytest.mark.parametrize("invalid_apply_to", ["inputt", "targett", 1])
def test_init_invalid_apply_to(invalid_apply_to):
    """
    Check that misspelled or non-string ``apply_to`` values raise ValueError.
    """
    bad_kwargs = {"apply_to": invalid_apply_to}
    with pytest.raises(ValueError):
        NormalizerDataCallback(**bad_kwargs)
@pytest.mark.parametrize("invalid_stage", ["trainn", "validatee", 1])
def test_init_invalid_stage(invalid_stage):
    """
    Check that misspelled or non-string ``stage`` values raise ValueError.
    """
    bad_kwargs = {"stage": invalid_stage}
    with pytest.raises(ValueError):
        NormalizerDataCallback(**bad_kwargs)
@pytest.mark.parametrize(
    "solver", [supervised_solver_lt, supervised_solver_no_lt]
)
@pytest.mark.parametrize(
    "fn", [[torch.std, torch.mean], [torch.var, torch.median]]
)
@pytest.mark.parametrize("apply_to", ["input", "target"])
@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"])
def test_setup(solver, fn, stage, apply_to):
    """
    Check the scale/shift parameters stored by the callback and the
    normalized points for the supervised solvers.

    For each condition, the stored ``scale`` and ``shift`` are compared with
    ``scale_fn`` / ``shift_fn`` evaluated on the training split of a
    reference copy of the data module, and the points of every dataset
    selected by ``stage`` are compared with ``(reference - shift) / scale``.

    :param solver: One of the module-level supervised solvers.
    :param fn: Pair ``[scale_fn, shift_fn]``.
    :param stage: Stage option passed to the callback; also selects the
        datasets to verify through ``stage_map``.
    :param apply_to: ``"input"`` or ``"target"``; the field that is checked.
    """
    scale_fn, shift_fn = fn
    trainer = Trainer(
        solver=solver,
        callbacks=NormalizerDataCallback(
            scale_fn=scale_fn, shift_fn=shift_fn, stage=stage, apply_to=apply_to
        ),
        max_epochs=1,
        train_size=0.4,
        val_size=0.3,
        test_size=0.3,
        shuffle=False,
    )

    # Reference copy: its data module is set up directly and the copy is
    # never trained or tested, so the test treats its datasets as the
    # original (un-normalized) data.
    trainer_copy = deepcopy(trainer)
    trainer_copy.data_module.setup("fit")
    trainer_copy.data_module.setup("test")

    trainer.train()
    trainer.test()

    normalizer = trainer.callbacks[0].normalizer
    reference_train = trainer_copy.data_module.train_dataset.conditions_dict

    for cond in ["data1", "data2"]:
        # The expected parameters are always computed on the training split,
        # whichever stage the callback is applied to.
        train_points = reference_train[cond][apply_to]
        scale = scale_fn(train_points)
        shift = shift_fn(train_points)

        assert "scale" in normalizer[cond]
        assert "shift" in normalizer[cond]
        # Two-sided tolerance: without abs() any stored value smaller than
        # the reference would pass regardless of how far off it is.
        assert abs(normalizer[cond]["scale"] - scale) < 1e-5
        assert abs(normalizer[cond]["shift"] - shift) < 1e-5

        for ds_name in stage_map[stage]:
            dataset = getattr(trainer.data_module, ds_name, None)
            old_dataset = getattr(trainer_copy.data_module, ds_name, None)
            current_points = dataset.conditions_dict[cond][apply_to]
            old_points = old_dataset.conditions_dict[cond][apply_to]
            expected = (old_points - shift) / scale
            assert torch.allclose(current_points, expected)
@pytest.mark.parametrize(
    "fn", [[torch.std, torch.mean], [torch.var, torch.median]]
)
@pytest.mark.parametrize("apply_to", ["input"])
@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"])
def test_setup_pinn(fn, stage, apply_to):
    """
    Check the callback on a PINN whose problem also has a supervised
    ``"data"`` condition.

    The callback's ``normalizer`` must contain exactly one entry, for
    ``"data"``. Its stored ``scale`` and ``shift`` are compared with
    ``scale_fn`` / ``shift_fn`` evaluated on the training split of a
    reference copy of the data module, and the points of every dataset
    selected by ``stage`` are compared with ``(reference - shift) / scale``.

    :param fn: Pair ``[scale_fn, shift_fn]``.
    :param stage: Stage option passed to the callback; also selects the
        datasets to verify through ``stage_map``.
    :param apply_to: Field that is checked (only ``"input"`` here).
    """
    scale_fn, shift_fn = fn
    pinn = PINN(
        problem=poisson_problem,
        model=FeedForward(2, 1),
    )
    poisson_problem.discretise_domain(n=10)
    trainer = Trainer(
        solver=pinn,
        callbacks=NormalizerDataCallback(
            scale_fn=scale_fn,
            shift_fn=shift_fn,
            stage=stage,
            apply_to=apply_to,
        ),
        max_epochs=1,
        train_size=0.4,
        val_size=0.3,
        test_size=0.3,
        shuffle=False,
    )

    # Reference copy: its data module is set up directly and the copy is
    # never trained or tested, so the test treats its datasets as the
    # original (un-normalized) data.
    trainer_copy = deepcopy(trainer)
    trainer_copy.data_module.setup("fit")
    trainer_copy.data_module.setup("test")

    trainer.train()
    trainer.test()

    normalizer = trainer.callbacks[0].normalizer
    conditions = normalizer.keys()
    # Only the supervised "data" condition may have normalization parameters.
    assert "data" in conditions
    assert len(conditions) == 1

    cond = "data"
    # The expected parameters are always computed on the training split,
    # whichever stage the callback is applied to.
    train_points = trainer_copy.data_module.train_dataset.conditions_dict[
        cond
    ][apply_to]
    scale = scale_fn(train_points)
    shift = shift_fn(train_points)

    assert "scale" in normalizer[cond]
    assert "shift" in normalizer[cond]
    # Two-sided tolerance: without abs() any stored value smaller than the
    # reference would pass regardless of how far off it is.
    assert abs(normalizer[cond]["scale"] - scale) < 1e-5
    assert abs(normalizer[cond]["shift"] - shift) < 1e-5

    for ds_name in stage_map[stage]:
        dataset = getattr(trainer.data_module, ds_name, None)
        old_dataset = getattr(trainer_copy.data_module, ds_name, None)
        current_points = dataset.conditions_dict[cond][apply_to]
        old_points = old_dataset.conditions_dict[cond][apply_to]
        expected = (old_points - shift) / scale
        assert torch.allclose(current_points, expected)
def test_setup_graph_dataset():
    """
    Check that training a supervised solver on the graph-based problem with
    the normalizer callback attached raises NotImplementedError.
    """
    graph_solver = SupervisedSolver(
        problem=GraphProblem(), model=FeedForward(2, 1), use_lt=False
    )
    callback = NormalizerDataCallback(
        scale_fn=torch.std,
        shift_fn=torch.mean,
        stage="all",
        apply_to="input",
    )
    trainer_options = {
        "max_epochs": 1,
        "train_size": 0.4,
        "val_size": 0.3,
        "test_size": 0.3,
        "shuffle": False,
    }
    trainer = Trainer(
        solver=graph_solver, callbacks=callback, **trainer_options
    )
    with pytest.raises(NotImplementedError):
        trainer.train()