From ef3542486c52c429bf1ed3099627e56be3c6eb2f Mon Sep 17 00:00:00 2001 From: giovanni Date: Fri, 5 Sep 2025 10:12:09 +0200 Subject: [PATCH] add linear weighting --- docs/source/_rst/_code.rst | 4 +- docs/source/_rst/loss/linear_weighting.rst | 9 ++ pina/loss/__init__.py | 2 + pina/loss/linear_weighting.py | 64 +++++++++++++ pina/loss/ntk_weighting.py | 5 +- pina/loss/scalar_weighting.py | 6 +- tests/test_weighting/test_linear_weighting.py | 95 +++++++++++++++++++ 7 files changed, 176 insertions(+), 9 deletions(-) create mode 100644 docs/source/_rst/loss/linear_weighting.rst create mode 100644 pina/loss/linear_weighting.py create mode 100644 tests/test_weighting/test_linear_weighting.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 2bb62a4..a724256 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -253,7 +253,6 @@ Callbacks Optimizer callback R3 Refinment callback Refinment Interface callback - Weighting callback Losses and Weightings --------------------- @@ -267,4 +266,5 @@ Losses and Weightings WeightingInterface ScalarWeighting NeuralTangentKernelWeighting - SelfAdaptiveWeighting \ No newline at end of file + SelfAdaptiveWeighting + LinearWeighting \ No newline at end of file diff --git a/docs/source/_rst/loss/linear_weighting.rst b/docs/source/_rst/loss/linear_weighting.rst new file mode 100644 index 0000000..16e6232 --- /dev/null +++ b/docs/source/_rst/loss/linear_weighting.rst @@ -0,0 +1,9 @@ +LinearWeighting +============================= +.. currentmodule:: pina.loss.linear_weighting + +.. automodule:: pina.loss.linear_weighting + +.. autoclass:: LinearWeighting + :members: + :show-inheritance: diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index fc47e62..d91cf7a 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -8,6 +8,7 @@ __all__ = [ "ScalarWeighting", "NeuralTangentKernelWeighting", "SelfAdaptiveWeighting", + "LinearWeighting", ] from .loss_interface import LossInterface @@ -17,3 +18,4 @@ from .weighting_interface import WeightingInterface from .scalar_weighting import ScalarWeighting from .ntk_weighting import NeuralTangentKernelWeighting from .self_adaptive_weighting import SelfAdaptiveWeighting +from .linear_weighting import LinearWeighting diff --git a/pina/loss/linear_weighting.py b/pina/loss/linear_weighting.py new file mode 100644 index 0000000..9049b52 --- /dev/null +++ b/pina/loss/linear_weighting.py @@ -0,0 +1,64 @@ +"""Module for the LinearWeighting class.""" + +from ..loss import WeightingInterface +from ..utils import check_consistency, check_positive_integer + + +class LinearWeighting(WeightingInterface): + """ + A weighting scheme that linearly scales weights from initial values to final + values over a specified number of epochs. + """ + + def __init__(self, initial_weights, final_weights, target_epoch): + """ + :param dict initial_weights: The weights to be assigned to each loss + term at the beginning of training. The keys are the conditions and + the values are the corresponding weights. If a condition is not + present in the dictionary, the default value (1) is used. + :param dict final_weights: The weights to be assigned to each loss term + once the target epoch is reached. The keys are the conditions and + the values are the corresponding weights. If a condition is not + present in the dictionary, the default value (1) is used. + :param int target_epoch: The epoch at which the weights reach their + final values. + :raises ValueError: If the keys of the two dictionaries are not + consistent. + """ + super().__init__(update_every_n_epochs=1, aggregator="sum") + + # Check consistency + check_consistency([initial_weights, final_weights], dict) + check_positive_integer(value=target_epoch, strict=True) + + # Check that the keys of the two dictionaries are the same + if initial_weights.keys() != final_weights.keys(): + raise ValueError( + "The keys of the initial_weights and final_weights " + "dictionaries must be the same." + ) + + # Initialization + self.initial_weights = initial_weights + self.final_weights = final_weights + self.target_epoch = target_epoch + + def weights_update(self, losses): + """ + Update the weighting scheme based on the given losses. + + :param dict losses: The dictionary of losses. + :return: The updated weights. + :rtype: dict + """ + return { + condition: self.last_saved_weights().get( + condition, self.initial_weights.get(condition, 1) + ) + + ( + self.final_weights.get(condition, 1) + - self.initial_weights.get(condition, 1) + ) + / (self.target_epoch) + for condition in losses.keys() + } diff --git a/pina/loss/ntk_weighting.py b/pina/loss/ntk_weighting.py index b888126..fe67115 100644 --- a/pina/loss/ntk_weighting.py +++ b/pina/loss/ntk_weighting.py @@ -61,11 +61,10 @@ class NeuralTangentKernelWeighting(WeightingInterface): losses_norm[condition] = grads.norm() # Update the weights - self.weights = { - condition: self.alpha * self.weights.get(condition, 1) + return { + condition: self.alpha * self.last_saved_weights().get(condition, 1) + (1 - self.alpha) * losses_norm[condition] / sum(losses_norm.values()) for condition in losses } - return self.weights diff --git a/pina/loss/scalar_weighting.py b/pina/loss/scalar_weighting.py index d770c89..692c493 100644 --- a/pina/loss/scalar_weighting.py +++ b/pina/loss/scalar_weighting.py @@ -17,7 +17,7 @@ class ScalarWeighting(WeightingInterface): If a single scalar value is provided, it is assigned to all loss terms. If a dictionary is provided, the keys are the conditions and the values are the weights. If a condition is not present in the - dictionary, the default value is used. + dictionary, the default value (1) is used. :type weights: float | int | dict """ super().__init__(update_every_n_epochs=1, aggregator="sum") @@ -29,11 +29,9 @@ class ScalarWeighting(WeightingInterface): if isinstance(weights, dict): self.values = weights self.default_value_weights = 1 - elif isinstance(weights, (float, int)): + else: self.values = {} self.default_value_weights = weights - else: - raise ValueError def weights_update(self, losses): """ diff --git a/tests/test_weighting/test_linear_weighting.py b/tests/test_weighting/test_linear_weighting.py new file mode 100644 index 0000000..a119520 --- /dev/null +++ b/tests/test_weighting/test_linear_weighting.py @@ -0,0 +1,95 @@ +import math +import pytest +from pina import Trainer +from pina.solver import PINN +from pina.model import FeedForward +from pina.loss import LinearWeighting +from pina.problem.zoo import Poisson2DSquareProblem + + +# Initialize problem and model +problem = Poisson2DSquareProblem() +problem.discretise_domain(10) +model = FeedForward(len(problem.input_variables), len(problem.output_variables)) + +# Weights for testing +init_weight_1 = {cond: 3 for cond in problem.conditions.keys()} +init_weight_2 = {cond: 4 for cond in problem.conditions.keys()} +final_weight_1 = {cond: 1 for cond in problem.conditions.keys()} +final_weight_2 = {cond: 5 for cond in problem.conditions.keys()} + + +@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2]) +@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2]) +@pytest.mark.parametrize("target_epoch", [5, 10]) +def test_constructor(initial_weights, final_weights, target_epoch): + LinearWeighting( + initial_weights=initial_weights, + final_weights=final_weights, + target_epoch=target_epoch, + ) + + # Should fail if initial_weights is not a dictionary + with pytest.raises(ValueError): + LinearWeighting( + initial_weights=[1, 1, 1], + final_weights=final_weights, + target_epoch=target_epoch, + ) + + # Should fail if final_weights is not a dictionary + with pytest.raises(ValueError): + LinearWeighting( + initial_weights=initial_weights, + final_weights=[1, 1, 1], + target_epoch=target_epoch, + ) + + # Should fail if target_epoch is not an integer + with pytest.raises(AssertionError): + LinearWeighting( + initial_weights=initial_weights, + final_weights=final_weights, + target_epoch=1.5, + ) + + # Should fail if target_epoch is not positive + with pytest.raises(AssertionError): + LinearWeighting( + initial_weights=initial_weights, + final_weights=final_weights, + target_epoch=0, + ) + + # Should fail if dictionary keys do not match + with pytest.raises(ValueError): + LinearWeighting( + initial_weights={list(initial_weights.keys())[0]: 1}, + final_weights=final_weights, + target_epoch=target_epoch, + ) + + +@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2]) +@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2]) +@pytest.mark.parametrize("target_epoch", [5, 10]) +def test_train_aggregation(initial_weights, final_weights, target_epoch): + weighting = LinearWeighting( + initial_weights=initial_weights, + final_weights=final_weights, + target_epoch=target_epoch, + ) + solver = PINN(problem=problem, model=model, weighting=weighting) + trainer = Trainer(solver=solver, max_epochs=target_epoch, accelerator="cpu") + trainer.train() + + # Check that weights are updated correctly + assert all( + math.isclose( + weighting.last_saved_weights()[cond], + final_weights[cond], + rel_tol=1e-5, + abs_tol=1e-8, + ) + for cond in final_weights.keys() + )