From 01aeb17673c6446125334bf9ff1f9ede1c4d4758 Mon Sep 17 00:00:00 2001 From: Giuseppe Alessio D'Inverno <66356297+AleDinve@users.noreply.github.com> Date: Wed, 19 Mar 2025 12:44:07 -0400 Subject: [PATCH] Neural Tangent Kernel integration + typo fix (#505) * NTK weighting + typo fixing * black code formatter + .rst docs --------- Co-authored-by: Dario Coscia --- docs/source/_rst/loss/ntk_weighting.rst | 9 +++ pina/loss/__init__.py | 2 + pina/loss/ntk_weighting.py | 71 ++++++++++++++++++++++ tests/test_weighting/test_ntk_weighting.py | 65 ++++++++++++++++++++ 4 files changed, 147 insertions(+) create mode 100644 docs/source/_rst/loss/ntk_weighting.rst create mode 100644 pina/loss/ntk_weighting.py create mode 100644 tests/test_weighting/test_ntk_weighting.py diff --git a/docs/source/_rst/loss/ntk_weighting.rst b/docs/source/_rst/loss/ntk_weighting.rst new file mode 100644 index 0000000..6d9d881 --- /dev/null +++ b/docs/source/_rst/loss/ntk_weighting.rst @@ -0,0 +1,9 @@ +NeuralTangentKernelWeighting +============================= +.. currentmodule:: pina.loss.ntk_weighting + +.. automodule:: pina.loss.ntk_weighting + +.. 
class NeuralTangentKernelWeighting(WeightingInterface):
    """
    A neural tangent kernel scheme for weighting different losses to
    boost the convergence.

    At every call of :meth:`aggregate`, each loss term receives a weight
    given by the share of its parameter-gradient norm in the sum of all
    gradient norms. The weight is blended with the previous one through a
    moving average controlled by ``alpha``.

    .. seealso::

        **Original reference**: Wang, Sifan, Xinling Yu, and
        Paris Perdikaris. *When and why PINNs fail to train:
        A neural tangent kernel perspective*. Journal of
        Computational Physics 449 (2022): 110768.
        DOI: `10.1016/j.jcp.2021.110768
        <https://doi.org/10.1016/j.jcp.2021.110768>`_.
    """

    def __init__(self, model, alpha=0.5):
        """
        Initialization of the :class:`NeuralTangentKernelWeighting` class.

        :param torch.nn.Module model: The neural network model whose
            parameters are used to measure the gradient of each loss.
        :param float alpha: The moving-average coefficient in ``[0, 1]``.
            ``alpha=1`` keeps the previous weights unchanged, ``alpha=0``
            uses only the current gradient-norm ratios. Default is ``0.5``.
        :raises ValueError: If ``alpha`` is outside ``[0, 1]``. Type
            validation of ``alpha`` and ``model`` is delegated to
            ``check_consistency``.
        """
        super().__init__()
        check_consistency(alpha, float)
        check_consistency(model, Module)
        if alpha < 0 or alpha > 1:
            raise ValueError("alpha should be a value between 0 and 1")
        self.alpha = alpha
        self.model = model
        self.weights = {}
        self.default_value_weights = 1

    def aggregate(self, losses):
        """
        Weights the losses according to the Neural Tangent Kernel
        algorithm.

        :param dict(torch.Tensor) losses: The dictionary of losses, keyed
            by condition name.
        :return: The losses aggregation. It should be a scalar Tensor.
        :rtype: torch.Tensor
        """
        # Frozen parameters cannot be differentiated against.
        params = [p for p in self.model.parameters() if p.requires_grad]

        losses_norm = {}
        for condition, loss in losses.items():
            # torch.autograd.grad returns the gradients instead of
            # accumulating them into ``param.grad``. Each condition is thus
            # measured on its own, and the later backward pass of the
            # aggregated loss does not see leftover gradients.
            grads = (
                torch.autograd.grad(
                    loss, params, retain_graph=True, allow_unused=True
                )
                if params
                else ()
            )
            # Parameters that do not contribute to this loss yield ``None``.
            flat = [g.reshape(-1) for g in grads if g is not None]
            losses_norm[condition] = (
                torch.norm(torch.cat(flat)) if flat else loss.new_zeros(())
            )

        # Hoisted: the normalisation constant is the same for all conditions.
        total_norm = sum(losses_norm.values())

        # With no gradient signal the ratio is 0/0; skip the update rather
        # than storing NaN weights that the moving average would never
        # recover from.
        if total_norm > 0:
            self.weights = {
                condition: self.alpha
                * self.weights.get(condition, self.default_value_weights)
                + (1 - self.alpha) * norm / total_norm
                for condition, norm in losses_norm.items()
            }

        return sum(
            self.weights.get(condition, self.default_value_weights) * loss
            for condition, loss in losses.items()
        )
NeuralTangentKernelWeighting(model) + + +@pytest.mark.parametrize( + "model,alpha", + [ + ( + FeedForward( + len(problem.input_variables), len(problem.output_variables) + ), + 1.2, + ) + ], +) +def test_wrong_constructor2(model, alpha): + with pytest.raises(ValueError): + NeuralTangentKernelWeighting(model, alpha) + + +@pytest.mark.parametrize( + "model,alpha", + [ + ( + FeedForward( + len(problem.input_variables), len(problem.output_variables) + ), + 0.5, + ) + ], +) +def test_train_aggregation(model, alpha): + weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha) + problem.discretise_domain(50) + solver = PINN(problem=problem, model=model, weighting=weighting) + trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") + trainer.train()