Add Normalizer Callback (#631)
* add normalizer callback * implement shift and scale parameters computation * change name files normalizer data callback * reduce tests * fix documentation * add NotImplementedError for PinaGraphDataset --------- Co-authored-by: FilippoOlivo <filippo@filippoolivo.com> Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
This commit is contained in:
@@ -253,6 +253,7 @@ Callbacks
|
|||||||
Optimizer callback <callback/optimizer_callback.rst>
|
Optimizer callback <callback/optimizer_callback.rst>
|
||||||
R3 Refinment callback <callback/refinement/r3_refinement.rst>
|
R3 Refinment callback <callback/refinement/r3_refinement.rst>
|
||||||
Refinment Interface callback <callback/refinement/refinement_interface.rst>
|
Refinment Interface callback <callback/refinement/refinement_interface.rst>
|
||||||
|
Normalizer callback <callback/normalizer_data_callback.rst>
|
||||||
|
|
||||||
Losses and Weightings
|
Losses and Weightings
|
||||||
---------------------
|
---------------------
|
||||||
|
|||||||
7
docs/source/_rst/callback/normalizer_data_callback.rst
Normal file
7
docs/source/_rst/callback/normalizer_data_callback.rst
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
Normalizer callbacks
|
||||||
|
=======================
|
||||||
|
|
||||||
|
.. currentmodule:: pina.callback.normalizer_data_callback
|
||||||
|
.. autoclass:: NormalizerDataCallback
|
||||||
|
:members:
|
||||||
|
:show-inheritance:
|
||||||
@@ -5,8 +5,10 @@ __all__ = [
|
|||||||
"MetricTracker",
|
"MetricTracker",
|
||||||
"PINAProgressBar",
|
"PINAProgressBar",
|
||||||
"R3Refinement",
|
"R3Refinement",
|
||||||
|
"NormalizerDataCallback",
|
||||||
]
|
]
|
||||||
|
|
||||||
from .optimizer_callback import SwitchOptimizer
|
from .optimizer_callback import SwitchOptimizer
|
||||||
from .processing_callback import MetricTracker, PINAProgressBar
|
from .processing_callback import MetricTracker, PINAProgressBar
|
||||||
from .refinement import R3Refinement
|
from .refinement import R3Refinement
|
||||||
|
from .normalizer_data_callback import NormalizerDataCallback
|
||||||
|
|||||||
228
pina/callback/normalizer_data_callback.py
Normal file
228
pina/callback/normalizer_data_callback.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""Module for the Normalizer callback."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lightning.pytorch import Callback
|
||||||
|
from ..label_tensor import LabelTensor
|
||||||
|
from ..utils import check_consistency, is_function
|
||||||
|
from ..condition import InputTargetCondition
|
||||||
|
from ..data.dataset import PinaGraphDataset
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizerDataCallback(Callback):
|
||||||
|
r"""
|
||||||
|
A Callback used to normalize the dataset inputs or targets according to
|
||||||
|
user-provided scale and shift functions.
|
||||||
|
|
||||||
|
The transformation is applied as:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}}
|
||||||
|
|
||||||
|
:Example:
|
||||||
|
|
||||||
|
>>> NormalizerDataCallback()
|
||||||
|
>>> NormalizerDataCallback(
|
||||||
|
... scale_fn: torch.std,
|
||||||
|
... shift_fn: torch.mean,
|
||||||
|
... stage: "all",
|
||||||
|
... apply_to: "input",
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scale_fn=torch.std,
|
||||||
|
shift_fn=torch.mean,
|
||||||
|
stage="all",
|
||||||
|
apply_to="input",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialization of the :class:`NormalizerDataCallback` class.
|
||||||
|
|
||||||
|
:param Callable scale_fn: The function to compute the scaling factor.
|
||||||
|
Default is ``torch.std``.
|
||||||
|
:param Callable shift_fn: The function to compute the shifting factor.
|
||||||
|
Default is ``torch.mean``.
|
||||||
|
:param str stage: The stage in which normalization is applied.
|
||||||
|
Accepted values are "train", "validate", "test", or "all".
|
||||||
|
Default is ``"all"``.
|
||||||
|
:param str apply_to: Whether to normalize "input" or "target" data.
|
||||||
|
Default is ``"input"``.
|
||||||
|
:raises ValueError: If ``scale_fn`` is not callable.
|
||||||
|
:raises ValueError: If ``shift_fn`` is not callable.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Validate parameters
|
||||||
|
self.apply_to = self._validate_apply_to(apply_to)
|
||||||
|
self.stage = self._validate_stage(stage)
|
||||||
|
|
||||||
|
# Validate functions
|
||||||
|
if not is_function(scale_fn):
|
||||||
|
raise ValueError(f"scale_fn must be Callable, got {scale_fn}")
|
||||||
|
if not is_function(shift_fn):
|
||||||
|
raise ValueError(f"shift_fn must be Callable, got {shift_fn}")
|
||||||
|
self.scale_fn = scale_fn
|
||||||
|
self.shift_fn = shift_fn
|
||||||
|
|
||||||
|
# Initialize normalizer dictionary
|
||||||
|
self._normalizer = {}
|
||||||
|
|
||||||
|
def _validate_apply_to(self, apply_to):
|
||||||
|
"""
|
||||||
|
Validate the ``apply_to`` parameter.
|
||||||
|
|
||||||
|
:param str apply_to: The candidate value for the ``apply_to`` parameter.
|
||||||
|
:raises ValueError: If ``apply_to`` is neither "input" nor "target".
|
||||||
|
:return: The validated ``apply_to`` value.
|
||||||
|
:rtype: str
|
||||||
|
"""
|
||||||
|
check_consistency(apply_to, str)
|
||||||
|
if apply_to not in {"input", "target"}:
|
||||||
|
raise ValueError(
|
||||||
|
f"apply_to must be either 'input' or 'target', got {apply_to}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return apply_to
|
||||||
|
|
||||||
|
def _validate_stage(self, stage):
|
||||||
|
"""
|
||||||
|
Validate the ``stage`` parameter.
|
||||||
|
|
||||||
|
:param str stage: The candidate value for the ``stage`` parameter.
|
||||||
|
:raises ValueError: If ``stage`` is not one of "train", "validate",
|
||||||
|
"test", or "all".
|
||||||
|
:return: The validated ``stage`` value.
|
||||||
|
:rtype: str
|
||||||
|
"""
|
||||||
|
check_consistency(stage, str)
|
||||||
|
if stage not in {"train", "validate", "test", "all"}:
|
||||||
|
raise ValueError(
|
||||||
|
"stage must be one of 'train', 'validate', 'test', or 'all',"
|
||||||
|
f" got {stage}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return stage
|
||||||
|
|
||||||
|
def setup(self, trainer, pl_module, stage):
|
||||||
|
"""
|
||||||
|
Apply normalization during setup.
|
||||||
|
|
||||||
|
:param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
|
||||||
|
:param SolverInterface pl_module: A
|
||||||
|
:class:`~pina.solver.solver.SolverInterface` instance.
|
||||||
|
:param str stage: The current stage.
|
||||||
|
:raises RuntimeError: If the training dataset is not available when
|
||||||
|
computing normalization parameters.
|
||||||
|
:return: The result of the parent setup.
|
||||||
|
:rtype: Any
|
||||||
|
|
||||||
|
:raises NotImplementedError: If the dataset is graph-based.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Ensure datsets are not graph-based
|
||||||
|
if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"NormalizerDataCallback is not compatible with "
|
||||||
|
"graph-based datasets."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract conditions
|
||||||
|
conditions_to_normalize = [
|
||||||
|
name
|
||||||
|
for name, cond in pl_module.problem.conditions.items()
|
||||||
|
if isinstance(cond, InputTargetCondition)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Compute scale and shift parameters
|
||||||
|
if not self.normalizer:
|
||||||
|
if not trainer.datamodule.train_dataset:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Training dataset is not available. Cannot compute "
|
||||||
|
"normalization parameters."
|
||||||
|
)
|
||||||
|
self._compute_scale_shift(
|
||||||
|
conditions_to_normalize, trainer.datamodule.train_dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply normalization based on the specified stage
|
||||||
|
if stage == "fit" and self.stage in ["train", "all"]:
|
||||||
|
self.normalize_dataset(trainer.datamodule.train_dataset)
|
||||||
|
if stage == "fit" and self.stage in ["validate", "all"]:
|
||||||
|
self.normalize_dataset(trainer.datamodule.val_dataset)
|
||||||
|
if stage == "test" and self.stage in ["test", "all"]:
|
||||||
|
self.normalize_dataset(trainer.datamodule.test_dataset)
|
||||||
|
|
||||||
|
return super().setup(trainer, pl_module, stage)
|
||||||
|
|
||||||
|
def _compute_scale_shift(self, conditions, dataset):
|
||||||
|
"""
|
||||||
|
Compute scale and shift parameters for each condition in the dataset.
|
||||||
|
|
||||||
|
:param list conditions: The list of condition names.
|
||||||
|
:param dataset: The `~pina.data.dataset.PinaDataset` dataset.
|
||||||
|
"""
|
||||||
|
for cond in conditions:
|
||||||
|
if cond in dataset.conditions_dict:
|
||||||
|
data = dataset.conditions_dict[cond][self.apply_to]
|
||||||
|
shift = self.shift_fn(data)
|
||||||
|
scale = self.scale_fn(data)
|
||||||
|
self._normalizer[cond] = {
|
||||||
|
"shift": shift,
|
||||||
|
"scale": scale,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _norm_fn(value, scale, shift):
|
||||||
|
"""
|
||||||
|
Normalize a value according to the scale and shift parameters.
|
||||||
|
|
||||||
|
:param value: The input tensor to normalize.
|
||||||
|
:type value: torch.Tensor | LabelTensor
|
||||||
|
:param float scale: The scaling factor.
|
||||||
|
:param float shift: The shifting factor.
|
||||||
|
:return: The normalized tensor.
|
||||||
|
:rtype: torch.Tensor | LabelTensor
|
||||||
|
"""
|
||||||
|
scaled_value = (value - shift) / scale
|
||||||
|
if isinstance(value, LabelTensor):
|
||||||
|
scaled_value = LabelTensor(scaled_value, value.labels)
|
||||||
|
|
||||||
|
return scaled_value
|
||||||
|
|
||||||
|
def normalize_dataset(self, dataset):
|
||||||
|
"""
|
||||||
|
Apply in-place normalization to the dataset.
|
||||||
|
|
||||||
|
:param PinaDataset dataset: The dataset to be normalized.
|
||||||
|
"""
|
||||||
|
# Initialize update dictionary
|
||||||
|
update_dataset_dict = {}
|
||||||
|
|
||||||
|
# Iterate over conditions and apply normalization
|
||||||
|
for cond, norm_params in self.normalizer.items():
|
||||||
|
points = dataset.conditions_dict[cond][self.apply_to]
|
||||||
|
scale = norm_params["scale"]
|
||||||
|
shift = norm_params["shift"]
|
||||||
|
normalized_points = self._norm_fn(points, scale, shift)
|
||||||
|
update_dataset_dict[cond] = {
|
||||||
|
self.apply_to: (
|
||||||
|
LabelTensor(normalized_points, points.labels)
|
||||||
|
if isinstance(points, LabelTensor)
|
||||||
|
else normalized_points
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update the dataset in-place
|
||||||
|
dataset.update_data(update_dataset_dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def normalizer(self):
|
||||||
|
"""
|
||||||
|
Get the dictionary of normalization parameters.
|
||||||
|
|
||||||
|
:return: The dictionary of normalization parameters.
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
return self._normalizer
|
||||||
@@ -206,7 +206,7 @@ def is_function(f):
|
|||||||
:return: ``True`` if ``f`` is a function, ``False`` otherwise.
|
:return: ``True`` if ``f`` is a function, ``False`` otherwise.
|
||||||
:rtype: bool
|
:rtype: bool
|
||||||
"""
|
"""
|
||||||
return isinstance(f, (types.FunctionType, types.LambdaType))
|
return callable(f)
|
||||||
|
|
||||||
|
|
||||||
def chebyshev_roots(n):
|
def chebyshev_roots(n):
|
||||||
|
|||||||
244
tests/test_callback/test_normalizer_data_callback.py
Normal file
244
tests/test_callback/test_normalizer_data_callback.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from pina import Trainer, LabelTensor, Condition
|
||||||
|
from pina.solver import SupervisedSolver
|
||||||
|
from pina.model import FeedForward
|
||||||
|
from pina.callback import NormalizerDataCallback
|
||||||
|
from pina.problem import AbstractProblem
|
||||||
|
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
|
||||||
|
from pina.solver import PINN
|
||||||
|
from pina.graph import RadiusGraph
|
||||||
|
|
||||||
|
# for checking normalization
|
||||||
|
stage_map = {
|
||||||
|
"train": ["train_dataset"],
|
||||||
|
"validate": ["val_dataset"],
|
||||||
|
"test": ["test_dataset"],
|
||||||
|
"all": ["train_dataset", "val_dataset", "test_dataset"],
|
||||||
|
}
|
||||||
|
|
||||||
|
input_1 = torch.rand(20, 2) * 10
|
||||||
|
target_1 = torch.rand(20, 1) * 10
|
||||||
|
input_2 = torch.rand(20, 2) * 5
|
||||||
|
target_2 = torch.rand(20, 1) * 5
|
||||||
|
|
||||||
|
|
||||||
|
class LabelTensorProblem(AbstractProblem):
|
||||||
|
input_variables = ["u_0", "u_1"]
|
||||||
|
output_variables = ["u"]
|
||||||
|
conditions = {
|
||||||
|
"data1": Condition(
|
||||||
|
input=LabelTensor(input_1, ["u_0", "u_1"]),
|
||||||
|
target=LabelTensor(target_1, ["u"]),
|
||||||
|
),
|
||||||
|
"data2": Condition(
|
||||||
|
input=LabelTensor(input_2, ["u_0", "u_1"]),
|
||||||
|
target=LabelTensor(target_2, ["u"]),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TensorProblem(AbstractProblem):
|
||||||
|
input_variables = ["u_0", "u_1"]
|
||||||
|
output_variables = ["u"]
|
||||||
|
conditions = {
|
||||||
|
"data1": Condition(input=input_1, target=target_1),
|
||||||
|
"data2": Condition(input=input_2, target=target_2),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
input_graph = [RadiusGraph(radius=0.5, pos=torch.rand(10, 2)) for _ in range(5)]
|
||||||
|
output_graph = torch.rand(5, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphProblem(AbstractProblem):
|
||||||
|
input_variables = ["u_0", "u_1"]
|
||||||
|
output_variables = ["u"]
|
||||||
|
conditions = {
|
||||||
|
"data": Condition(input=input_graph, target=output_graph),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
supervised_solver_no_lt = SupervisedSolver(
|
||||||
|
problem=TensorProblem(), model=FeedForward(2, 1), use_lt=False
|
||||||
|
)
|
||||||
|
supervised_solver_lt = SupervisedSolver(
|
||||||
|
problem=LabelTensorProblem(), model=FeedForward(2, 1), use_lt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
poisson_problem = Poisson()
|
||||||
|
poisson_problem.conditions["data"] = Condition(
|
||||||
|
input=LabelTensor(torch.rand(20, 2) * 10, ["x", "y"]),
|
||||||
|
target=LabelTensor(torch.rand(20, 1) * 10, ["u"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("scale_fn", [torch.std, torch.var])
|
||||||
|
@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median])
|
||||||
|
@pytest.mark.parametrize("apply_to", ["input", "target"])
|
||||||
|
@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"])
|
||||||
|
def test_init(scale_fn, shift_fn, apply_to, stage):
|
||||||
|
normalizer = NormalizerDataCallback(
|
||||||
|
scale_fn=scale_fn, shift_fn=shift_fn, apply_to=apply_to, stage=stage
|
||||||
|
)
|
||||||
|
assert normalizer.scale_fn == scale_fn
|
||||||
|
assert normalizer.shift_fn == shift_fn
|
||||||
|
assert normalizer.apply_to == apply_to
|
||||||
|
assert normalizer.stage == stage
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_invalid_scale():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
NormalizerDataCallback(scale_fn=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_invalid_shift():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
NormalizerDataCallback(shift_fn=1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("invalid_apply_to", ["inputt", "targett", 1])
|
||||||
|
def test_init_invalid_apply_to(invalid_apply_to):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
NormalizerDataCallback(apply_to=invalid_apply_to)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("invalid_stage", ["trainn", "validatee", 1])
|
||||||
|
def test_init_invalid_stage(invalid_stage):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
NormalizerDataCallback(stage=invalid_stage)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"solver", [supervised_solver_lt, supervised_solver_no_lt]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"fn", [[torch.std, torch.mean], [torch.var, torch.median]]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("apply_to", ["input", "target"])
|
||||||
|
@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"])
|
||||||
|
def test_setup(solver, fn, stage, apply_to):
|
||||||
|
scale_fn, shift_fn = fn
|
||||||
|
trainer = Trainer(
|
||||||
|
solver=solver,
|
||||||
|
callbacks=NormalizerDataCallback(
|
||||||
|
scale_fn=scale_fn, shift_fn=shift_fn, stage=stage, apply_to=apply_to
|
||||||
|
),
|
||||||
|
max_epochs=1,
|
||||||
|
train_size=0.4,
|
||||||
|
val_size=0.3,
|
||||||
|
test_size=0.3,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
trainer_copy = deepcopy(trainer)
|
||||||
|
trainer_copy.data_module.setup("fit")
|
||||||
|
trainer_copy.data_module.setup("test")
|
||||||
|
trainer.train()
|
||||||
|
trainer.test()
|
||||||
|
|
||||||
|
normalizer = trainer.callbacks[0].normalizer
|
||||||
|
|
||||||
|
for cond in ["data1", "data2"]:
|
||||||
|
scale = scale_fn(
|
||||||
|
trainer_copy.data_module.train_dataset.conditions_dict[cond][
|
||||||
|
apply_to
|
||||||
|
]
|
||||||
|
)
|
||||||
|
shift = shift_fn(
|
||||||
|
trainer_copy.data_module.train_dataset.conditions_dict[cond][
|
||||||
|
apply_to
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert "scale" in normalizer[cond]
|
||||||
|
assert "shift" in normalizer[cond]
|
||||||
|
assert normalizer[cond]["scale"] - scale < 1e-5
|
||||||
|
assert normalizer[cond]["shift"] - shift < 1e-5
|
||||||
|
for ds_name in stage_map[stage]:
|
||||||
|
dataset = getattr(trainer.data_module, ds_name, None)
|
||||||
|
old_dataset = getattr(trainer_copy.data_module, ds_name, None)
|
||||||
|
current_points = dataset.conditions_dict[cond][apply_to]
|
||||||
|
old_points = old_dataset.conditions_dict[cond][apply_to]
|
||||||
|
expected = (old_points - shift) / scale
|
||||||
|
assert torch.allclose(current_points, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"fn", [[torch.std, torch.mean], [torch.var, torch.median]]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("apply_to", ["input"])
|
||||||
|
@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"])
|
||||||
|
def test_setup_pinn(fn, stage, apply_to):
|
||||||
|
scale_fn, shift_fn = fn
|
||||||
|
pinn = PINN(
|
||||||
|
problem=poisson_problem,
|
||||||
|
model=FeedForward(2, 1),
|
||||||
|
)
|
||||||
|
poisson_problem.discretise_domain(n=10)
|
||||||
|
trainer = Trainer(
|
||||||
|
solver=pinn,
|
||||||
|
callbacks=NormalizerDataCallback(
|
||||||
|
scale_fn=scale_fn,
|
||||||
|
shift_fn=shift_fn,
|
||||||
|
stage=stage,
|
||||||
|
apply_to=apply_to,
|
||||||
|
),
|
||||||
|
max_epochs=1,
|
||||||
|
train_size=0.4,
|
||||||
|
val_size=0.3,
|
||||||
|
test_size=0.3,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer_copy = deepcopy(trainer)
|
||||||
|
trainer_copy.data_module.setup("fit")
|
||||||
|
trainer_copy.data_module.setup("test")
|
||||||
|
trainer.train()
|
||||||
|
trainer.test()
|
||||||
|
|
||||||
|
conditions = trainer.callbacks[0].normalizer.keys()
|
||||||
|
assert "data" in conditions
|
||||||
|
assert len(conditions) == 1
|
||||||
|
normalizer = trainer.callbacks[0].normalizer
|
||||||
|
cond = "data"
|
||||||
|
|
||||||
|
scale = scale_fn(
|
||||||
|
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
|
||||||
|
)
|
||||||
|
shift = shift_fn(
|
||||||
|
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to]
|
||||||
|
)
|
||||||
|
assert "scale" in normalizer[cond]
|
||||||
|
assert "shift" in normalizer[cond]
|
||||||
|
assert normalizer[cond]["scale"] - scale < 1e-5
|
||||||
|
assert normalizer[cond]["shift"] - shift < 1e-5
|
||||||
|
for ds_name in stage_map[stage]:
|
||||||
|
dataset = getattr(trainer.data_module, ds_name, None)
|
||||||
|
old_dataset = getattr(trainer_copy.data_module, ds_name, None)
|
||||||
|
current_points = dataset.conditions_dict[cond][apply_to]
|
||||||
|
old_points = old_dataset.conditions_dict[cond][apply_to]
|
||||||
|
expected = (old_points - shift) / scale
|
||||||
|
assert torch.allclose(current_points, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_graph_dataset():
|
||||||
|
solver = SupervisedSolver(
|
||||||
|
problem=GraphProblem(), model=FeedForward(2, 1), use_lt=False
|
||||||
|
)
|
||||||
|
trainer = Trainer(
|
||||||
|
solver=solver,
|
||||||
|
callbacks=NormalizerDataCallback(
|
||||||
|
scale_fn=torch.std,
|
||||||
|
shift_fn=torch.mean,
|
||||||
|
stage="all",
|
||||||
|
apply_to="input",
|
||||||
|
),
|
||||||
|
max_epochs=1,
|
||||||
|
train_size=0.4,
|
||||||
|
val_size=0.3,
|
||||||
|
test_size=0.3,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
trainer.train()
|
||||||
Reference in New Issue
Block a user