From bacd7e202ac2857c188702aec0328dfbac4f3308 Mon Sep 17 00:00:00 2001
From: giovanni
Date: Fri, 29 Aug 2025 19:11:08 +0200
Subject: [PATCH] add mutual solver-weighting link

---
 pina/loss/ntk_weighting.py                    | 40 ++++++------
 pina/loss/scalar_weighting.py                 |  6 +-
 pina/loss/weighting_interface.py              | 12 +++-
 pina/solver/solver.py                         |  2 +-
 tests/test_weighting/test_ntk_weighting.py    | 62 +++++--------
 ..._weighting.py => test_scalar_weighting.py} | 18 +++---
 6 files changed, 64 insertions(+), 76 deletions(-)
 rename tests/test_weighting/{test_standard_weighting.py => test_scalar_weighting.py} (82%)

diff --git a/pina/loss/ntk_weighting.py b/pina/loss/ntk_weighting.py
index d8c947f..6149f23 100644
--- a/pina/loss/ntk_weighting.py
+++ b/pina/loss/ntk_weighting.py
@@ -1,7 +1,6 @@
 """Module for Neural Tangent Kernel Class"""
 
 import torch
-from torch.nn import Module
 from .weighting_interface import WeightingInterface
 from ..utils import check_consistency
 
@@ -21,43 +20,47 @@ class NeuralTangentKernelWeighting(WeightingInterface):
     """
 
-    def __init__(self, model, alpha=0.5):
+    def __init__(self, alpha=0.5):
         """
         Initialization of the :class:`NeuralTangentKernelWeighting` class.
 
-        :param torch.nn.Module model: The neural network model.
-        :param float alpha: The alpha parameter.
+        :param float alpha: The moving-average parameter, between 0 and 1.
         :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
         """
         super().__init__()
+
+        # Check consistency
         check_consistency(alpha, float)
-        check_consistency(model, Module)
         if alpha < 0 or alpha > 1:
             raise ValueError("alpha should be a value between 0 and 1")
+
+        # Initialize parameters
         self.alpha = alpha
-        self.model = model
         self.weights = {}
-        self.default_value_weights = 1
+        self.default_value_weights = 1.0
 
     def aggregate(self, losses):
         """
-        Weight the losses according to the Neural Tangent Kernel
-        algorithm.
+        Weight the losses according to the Neural Tangent Kernel algorithm.
 
-        :param dict(torch.Tensor) input: The dictionary of losses.
-        :return: The losses aggregation. It should be a scalar Tensor.
+        :param dict(torch.Tensor) losses: The dictionary of losses.
+        :return: The aggregation of the losses. It should be a scalar Tensor.
         :rtype: torch.Tensor
         """
+        # Define a dictionary to store the norms of the gradients
         losses_norm = {}
-        for condition in losses:
-            losses[condition].backward(retain_graph=True)
-            grads = []
-            for param in self.model.parameters():
-                grads.append(param.grad.view(-1))
-            grads = torch.cat(grads)
-            losses_norm[condition] = torch.norm(grads)
+
+        # Compute the gradient norms for each loss component
+        for condition, loss in losses.items():
+            loss.backward(retain_graph=True)
+            grads = torch.cat(
+                [p.grad.flatten() for p in self.solver.model.parameters()]
+            )
+            losses_norm[condition] = grads.norm()
+
+        # Update the weights
         self.weights = {
             condition: self.alpha
             * self.weights.get(condition, self.default_value_weights)
@@ -66,6 +69,7 @@ class NeuralTangentKernelWeighting(WeightingInterface):
             / sum(losses_norm.values())
             for condition in losses
         }
+
         return sum(
             self.weights[condition] * loss for condition, loss in losses.items()
         )
diff --git a/pina/loss/scalar_weighting.py b/pina/loss/scalar_weighting.py
index 6bc093c..c10b574 100644
--- a/pina/loss/scalar_weighting.py
+++ b/pina/loss/scalar_weighting.py
@@ -37,12 +37,16 @@ class ScalarWeighting(WeightingInterface):
         :type weights: float | int | dict
         """
         super().__init__()
+
+        # Check consistency
         check_consistency([weights], (float, dict, int))
+
+        # Weights initialization
         if isinstance(weights, (float, int)):
             self.default_value_weights = weights
             self.weights = {}
         else:
-            self.default_value_weights = 1
+            self.default_value_weights = 1.0
             self.weights = weights
 
     def aggregate(self, losses):
diff --git a/pina/loss/weighting_interface.py b/pina/loss/weighting_interface.py
index 8b8cb2f..567d493 100644
--- a/pina/loss/weighting_interface.py
+++ b/pina/loss/weighting_interface.py
@@ -13,7 +13,7 @@ class WeightingInterface(metaclass=ABCMeta):
         """
         Initialization of the :class:`WeightingInterface` class.
         """
-        self.condition_names = None
+        self._solver = None
 
     @abstractmethod
     def aggregate(self, losses):
@@ -22,3 +22,13 @@
 
         :param dict losses: The dictionary of losses.
         """
+
+    @property
+    def solver(self):
+        """
+        The solver employing this weighting scheme.
+
+        :return: The solver.
+        :rtype: :class:`~pina.solver.SolverInterface`
+        """
+        return self._solver
diff --git a/pina/solver/solver.py b/pina/solver/solver.py
index f3ff405..6948ec6 100644
--- a/pina/solver/solver.py
+++ b/pina/solver/solver.py
@@ -44,7 +44,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
             weighting = _NoWeighting()
         check_consistency(weighting, WeightingInterface)
         self._pina_weighting = weighting
-        weighting.condition_names = list(self._pina_problem.conditions.keys())
+        weighting._solver = self
 
         # check consistency use_lt
         check_consistency(use_lt, bool)
diff --git a/tests/test_weighting/test_ntk_weighting.py b/tests/test_weighting/test_ntk_weighting.py
index 840237f..236c498 100644
--- a/tests/test_weighting/test_ntk_weighting.py
+++ b/tests/test_weighting/test_ntk_weighting.py
@@ -2,64 +2,32 @@ import pytest
 from pina import Trainer
 from pina.solver import PINN
 from pina.model import FeedForward
-from pina.problem.zoo import Poisson2DSquareProblem
 from pina.loss import NeuralTangentKernelWeighting
+from pina.problem.zoo import Poisson2DSquareProblem
 
+
+# Initialize problem and model
 problem = Poisson2DSquareProblem()
-condition_names = problem.conditions.keys()
+problem.discretise_domain(10)
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 
 
-@pytest.mark.parametrize(
-    "model,alpha",
-    [
-        (
-            FeedForward(
-                len(problem.input_variables), len(problem.output_variables)
-            ),
-            0.5,
-        )
-    ],
-)
-def test_constructor(model, alpha):
-    NeuralTangentKernelWeighting(model=model, alpha=alpha)
+@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
+def test_constructor(alpha):
+    NeuralTangentKernelWeighting(alpha=alpha)
 
-
-@pytest.mark.parametrize("model", [0.5])
-def test_wrong_constructor1(model):
+    # Should fail if alpha is below 0
     with pytest.raises(ValueError):
-        NeuralTangentKernelWeighting(model)
+        NeuralTangentKernelWeighting(alpha=-0.1)
 
-
-@pytest.mark.parametrize(
-    "model,alpha",
-    [
-        (
-            FeedForward(
-                len(problem.input_variables), len(problem.output_variables)
-            ),
-            1.2,
-        )
-    ],
-)
-def test_wrong_constructor2(model, alpha):
+    # Should fail if alpha is above 1
     with pytest.raises(ValueError):
-        NeuralTangentKernelWeighting(model, alpha)
+        NeuralTangentKernelWeighting(alpha=1.1)
 
 
-@pytest.mark.parametrize(
-    "model,alpha",
-    [
-        (
-            FeedForward(
-                len(problem.input_variables), len(problem.output_variables)
-            ),
-            0.5,
-        )
-    ],
-)
-def test_train_aggregation(model, alpha):
-    weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
-    problem.discretise_domain(50)
+@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
+def test_train_aggregation(alpha):
+    weighting = NeuralTangentKernelWeighting(alpha=alpha)
     solver = PINN(problem=problem, model=model, weighting=weighting)
     trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
     trainer.train()
diff --git a/tests/test_weighting/test_standard_weighting.py b/tests/test_weighting/test_scalar_weighting.py
similarity index 82%
rename from tests/test_weighting/test_standard_weighting.py
rename to tests/test_weighting/test_scalar_weighting.py
index 9caa89a..54b3293 100644
--- a/tests/test_weighting/test_standard_weighting.py
+++ b/tests/test_weighting/test_scalar_weighting.py
@@ -1,16 +1,17 @@
 import pytest
 import torch
 
-
 from pina import Trainer
 from pina.solver import PINN
 from pina.model import FeedForward
-from pina.problem.zoo import Poisson2DSquareProblem
 from pina.loss import ScalarWeighting
+from pina.problem.zoo import Poisson2DSquareProblem
 
+
+# Initialize problem and model
 problem = Poisson2DSquareProblem()
+problem.discretise_domain(50)
 model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 condition_names = problem.conditions.keys()
-print(problem.conditions.keys())
 
 
 @pytest.mark.parametrize(
@@ -19,11 +20,13 @@ print(problem.conditions.keys())
 def test_constructor(weights):
     ScalarWeighting(weights=weights)
 
-
-@pytest.mark.parametrize("weights", ["a", [1, 2, 3]])
-def test_wrong_constructor(weights):
+    # Should fail if weights is a string
     with pytest.raises(ValueError):
-        ScalarWeighting(weights=weights)
+        ScalarWeighting(weights="invalid")
+
+    # Should fail if weights is a list
+    with pytest.raises(ValueError):
+        ScalarWeighting(weights=[1, 2, 3])
 
 
 @pytest.mark.parametrize(
@@ -45,7 +48,6 @@ def test_aggregate(weights):
 )
 def test_train_aggregation(weights):
     weighting = ScalarWeighting(weights=weights)
-    problem.discretise_domain(50)
     solver = PINN(problem=problem, model=model, weighting=weighting)
     trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
     trainer.train()
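
With the mutual link in place, a custom weighting can reach the owning solver
(and through it the model) via ``self.solver``, exactly as the reworked NTK
weighting does. Below is a minimal sketch of a user-defined weighting against
the new interface; the class name and the inverse-norm rule are illustrative
only, not part of PINA:

    import torch

    from pina.loss.weighting_interface import WeightingInterface


    class InverseGradNormWeighting(WeightingInterface):
        """Toy weighting: scale each loss by the inverse of its gradient norm."""

        def aggregate(self, losses):
            # ``self.solver`` is populated by SolverInterface when the
            # weighting is passed to a solver.
            norms = {}
            for condition, loss in losses.items():
                loss.backward(retain_graph=True)
                grads = torch.cat(
                    [
                        p.grad.flatten()
                        for p in self.solver.model.parameters()
                        if p.grad is not None
                    ]
                )
                norms[condition] = grads.norm()
            # Down-weight loss components with large gradients; the small
            # eps guards against division by zero.
            eps = 1e-8
            return sum(
                loss / (norms[condition] + eps)
                for condition, loss in losses.items()
            )

It plugs in exactly like the built-in weightings, e.g.
``PINN(problem=problem, model=model, weighting=InverseGradNormWeighting())``.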