Neural Tangent Kernel integration + typo fix (#505)

* NTK weighting + typo fixing
* black code formatter + .rst docs

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Giuseppe Alessio D'Inverno
2025-03-19 12:44:07 -04:00
committed by Nicola Demo
parent 2b09cb95cf
commit 716d43f146
4 changed files with 147 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
NeuralTangentKernelWeighting
=============================
.. currentmodule:: pina.loss.ntk_weighting
.. automodule:: pina.loss.ntk_weighting
.. autoclass:: NeuralTangentKernelWeighting
:members:
:show-inheritance:

View File

@@ -6,6 +6,7 @@ __all__ = [
"PowerLoss",
"WeightingInterface",
"ScalarWeighting",
"NeuralTangentKernelWeighting",
]
from .loss_interface import LossInterface
@@ -13,3 +14,4 @@ from .power_loss import PowerLoss
from .lp_loss import LpLoss
from .weighting_interface import WeightingInterface
from .scalar_weighting import ScalarWeighting
from .ntk_weighting import NeuralTangentKernelWeighting

View File

@@ -0,0 +1,71 @@
"""Module for Neural Tangent Kernel Class"""
import torch
from torch.nn import Module
from .weighting_interface import WeightingInterface
from ..utils import check_consistency
class NeuralTangentKernelWeighting(WeightingInterface):
"""
A neural tangent kernel scheme for weighting different losses to
boost the convergence.
.. seealso::
**Original reference**: Wang, Sifan, Xinling Yu, and
Paris Perdikaris. *When and why PINNs fail to train:
A neural tangent kernel perspective*. Journal of
Computational Physics 449 (2022): 110768.
DOI: `10.1016/j.jcp.2021.110768 <https://doi.org/10.1016/j.jcp.2021.110768>`_.
"""
def __init__(self, model, alpha=0.5):
"""
Initialization of the :class:`NeuralTangentKernelWeighting` class.
:param torch.nn.Module model: The neural network model.
:param float alpha: The alpha parameter.
"""
super().__init__()
check_consistency(alpha, float)
check_consistency(model, Module)
if alpha < 0 or alpha > 1:
raise ValueError("alpha should be a value between 0 and 1")
self.alpha = alpha
self.model = model
self.weights = {}
self.default_value_weights = 1
def aggregate(self, losses):
"""
Weights the losses according to the Neural Tangent Kernel
algorithm.
:param dict(torch.Tensor) input: The dictionary of losses.
:return: The losses aggregation. It should be a scalar Tensor.
:rtype: torch.Tensor
"""
losses_norm = {}
for condition in losses:
losses[condition].backward(retain_graph=True)
grads = []
for param in self.model.parameters():
grads.append(param.grad.view(-1))
grads = torch.cat(grads)
losses_norm[condition] = torch.norm(grads)
self.weights = {
condition: self.alpha
* self.weights.get(condition, self.default_value_weights)
+ (1 - self.alpha)
* losses_norm[condition]
/ sum(losses_norm.values())
for condition in losses
}
return sum(
self.weights[condition] * loss for condition, loss in losses.items()
)

View File

@@ -0,0 +1,65 @@
import pytest
from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem
from pina.loss import NeuralTangentKernelWeighting
problem = Poisson2DSquareProblem()
condition_names = problem.conditions.keys()
@pytest.mark.parametrize(
"model,alpha",
[
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
0.5,
)
],
)
def test_constructor(model, alpha):
NeuralTangentKernelWeighting(model=model, alpha=alpha)
@pytest.mark.parametrize("model", [0.5])
def test_wrong_constructor1(model):
with pytest.raises(ValueError):
NeuralTangentKernelWeighting(model)
@pytest.mark.parametrize(
"model,alpha",
[
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
1.2,
)
],
)
def test_wrong_constructor2(model, alpha):
with pytest.raises(ValueError):
NeuralTangentKernelWeighting(model, alpha)
@pytest.mark.parametrize(
"model,alpha",
[
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
0.5,
)
],
)
def test_train_aggregation(model, alpha):
weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
problem.discretise_domain(50)
solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train()