From c42bdd575c388f3c4fbd2a25edb7e2a8e512fcfd Mon Sep 17 00:00:00 2001
From: giovanni
Date: Fri, 29 Aug 2025 19:11:48 +0200
Subject: [PATCH] add self-adaptive weighting

---
 docs/source/_rst/_code.rst                        |  1 +
 .../_rst/loss/self_adaptive_weighting.rst         |  9 +++
 pina/loss/__init__.py                             |  2 +
 pina/loss/self_adaptive_weighting.py              | 87 +++++++++++++++++++++
 .../test_self_adaptive_weighting.py               | 37 ++++++++++
 5 files changed, 136 insertions(+)
 create mode 100644 docs/source/_rst/loss/self_adaptive_weighting.rst
 create mode 100644 pina/loss/self_adaptive_weighting.py
 create mode 100644 tests/test_weighting/test_self_adaptive_weighting.py

diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index 9bd36ab..2bb62a4 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -267,3 +267,4 @@ Losses and Weightings
     WeightingInterface
     ScalarWeighting
     NeuralTangentKernelWeighting
+    SelfAdaptiveWeighting
\ No newline at end of file
diff --git a/docs/source/_rst/loss/self_adaptive_weighting.rst b/docs/source/_rst/loss/self_adaptive_weighting.rst
new file mode 100644
index 0000000..cd1daed
--- /dev/null
+++ b/docs/source/_rst/loss/self_adaptive_weighting.rst
@@ -0,0 +1,9 @@
+SelfAdaptiveWeighting
+=====================
+.. currentmodule:: pina.loss.self_adaptive_weighting
+
+.. automodule:: pina.loss.self_adaptive_weighting
+
+.. autoclass:: SelfAdaptiveWeighting
+    :members:
+    :show-inheritance:
\ No newline at end of file
diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py
index 2f15c6d..fc47e62 100644
--- a/pina/loss/__init__.py
+++ b/pina/loss/__init__.py
@@ -7,6 +7,7 @@ __all__ = [
     "WeightingInterface",
     "ScalarWeighting",
     "NeuralTangentKernelWeighting",
+    "SelfAdaptiveWeighting",
 ]
 
 from .loss_interface import LossInterface
@@ -15,3 +16,4 @@ from .lp_loss import LpLoss
 from .weighting_interface import WeightingInterface
 from .scalar_weighting import ScalarWeighting
 from .ntk_weighting import NeuralTangentKernelWeighting
+from .self_adaptive_weighting import SelfAdaptiveWeighting
diff --git a/pina/loss/self_adaptive_weighting.py b/pina/loss/self_adaptive_weighting.py
new file mode 100644
index 0000000..8533078
--- /dev/null
+++ b/pina/loss/self_adaptive_weighting.py
@@ -0,0 +1,87 @@
+"""Module for Self-Adaptive Weighting class."""
+
+import torch
+from .weighting_interface import WeightingInterface
+from ..utils import check_positive_integer
+
+
+class SelfAdaptiveWeighting(WeightingInterface):
+    """
+    A self-adaptive weighting scheme to tackle the imbalance among the loss
+    components. This formulation equalizes the gradient norms of the losses,
+    preventing bias toward any particular term during training.
+
+    .. seealso::
+
+        **Original reference**:
+        Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
+        *Simulating Three-dimensional Turbulence with Physics-informed Neural
+        Networks*.
+        arXiv preprint: `arXiv:2507.08972
+        <https://arxiv.org/abs/2507.08972>`_
+
+    """
+
+    def __init__(self, k=100):
+        """
+        Initialization of the :class:`SelfAdaptiveWeighting` class.
+
+        :param int k: The number of epochs after which the weights are updated.
+            Default is 100.
+
+        :raises AssertionError: If ``k`` is not a strictly positive integer.
+        """
+        super().__init__()
+
+        # Check consistency
+        check_positive_integer(value=k, strict=True)
+
+        # Initialize parameters
+        self.k = k
+        self.weights = {}
+        self.default_value_weights = 1.0
+
+    def aggregate(self, losses):
+        """
+        Weight the losses according to the self-adaptive algorithm.
+
+        :param dict(torch.Tensor) losses: The dictionary of losses.
+        :return: The aggregation of the losses. It should be a scalar Tensor.
+        :rtype: torch.Tensor
+        """
+        # If weights have not been initialized, set them to 1
+        if not self.weights:
+            self.weights = {
+                condition: self.default_value_weights for condition in losses
+            }
+
+        # Update the weights every k epochs
+        if self.solver.trainer.current_epoch % self.k == 0:
+
+            # Define a dictionary to store the norms of the gradients
+            losses_norm = {}
+
+            # Compute the gradient norm of each loss component
+            for condition, loss in losses.items():
+                # Reset gradients to avoid accumulation across components
+                self.solver.model.zero_grad()
+                loss.backward(retain_graph=True)
+                grads = torch.cat(
+                    [
+                        p.grad.flatten()
+                        for p in self.solver.model.parameters()
+                        if p.grad is not None
+                    ]
+                )
+                losses_norm[condition] = grads.norm()
+
+            # Update the weights: scale each component by the ratio of the
+            # total gradient norm to its own gradient norm
+            self.weights = {
+                condition: sum(losses_norm.values()) / losses_norm[condition]
+                for condition in losses
+            }
+
+        return sum(
+            self.weights[condition] * loss for condition, loss in losses.items()
+        )
diff --git a/tests/test_weighting/test_self_adaptive_weighting.py b/tests/test_weighting/test_self_adaptive_weighting.py
new file mode 100644
index 0000000..b82f545
--- /dev/null
+++ b/tests/test_weighting/test_self_adaptive_weighting.py
@@ -0,0 +1,37 @@
+import pytest
+from pina import Trainer
+from pina.solver import PINN
+from pina.model import FeedForward
+from pina.loss import SelfAdaptiveWeighting
+from pina.problem.zoo import Poisson2DSquareProblem
+
+
+# Initialize problem and model
+problem = Poisson2DSquareProblem()
+problem.discretise_domain(10)
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
+
+@pytest.mark.parametrize("k", [10, 100, 1000])
+def test_constructor(k):
+    SelfAdaptiveWeighting(k=k)
+
+    # Should fail if k is not an integer
+    with pytest.raises(AssertionError):
+        SelfAdaptiveWeighting(k=1.5)
+
+    # Should fail if k is zero
+    with pytest.raises(AssertionError):
+        SelfAdaptiveWeighting(k=0)
+
+    # Should fail if k is negative
+    with pytest.raises(AssertionError):
+        SelfAdaptiveWeighting(k=-3)
+
+
+@pytest.mark.parametrize("k", [2, 3])
+def test_train_aggregation(k):
+    weighting = SelfAdaptiveWeighting(k=k)
+    solver = PINN(problem=problem, model=model, weighting=weighting)
+    trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
+    trainer.train()
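
Usage note (outside the patch): a minimal sketch of how the new weighting
plugs into a PINN training run, using only the public API already exercised
by test_train_aggregation above. The discretisation size, network, and
max_epochs values below are illustrative assumptions, not part of the patch.

    from pina import Trainer
    from pina.solver import PINN
    from pina.model import FeedForward
    from pina.loss import SelfAdaptiveWeighting
    from pina.problem.zoo import Poisson2DSquareProblem

    # Set up a sample problem and a small network
    problem = Poisson2DSquareProblem()
    problem.discretise_domain(50)
    model = FeedForward(
        len(problem.input_variables), len(problem.output_variables)
    )

    # Re-balance the loss components every 100 epochs (the default k)
    weighting = SelfAdaptiveWeighting(k=100)
    solver = PINN(problem=problem, model=model, weighting=weighting)

    # Weights are recomputed from the per-condition gradient norms whenever
    # current_epoch is a multiple of k, starting at epoch 0
    trainer = Trainer(solver=solver, max_epochs=500, accelerator="cpu")
    trainer.train()

A smaller k re-balances more often, at the cost of one extra backward pass
per loss component on each update epoch.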