Neural Tangent Kernel integration + typo fix (#505)
* NTK weighting + typo fixing * black code formatter + .rst docs --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
2b09cb95cf
commit
716d43f146
9
docs/source/_rst/loss/ntk_weighting.rst
Normal file
9
docs/source/_rst/loss/ntk_weighting.rst
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
NeuralTangentKernelWeighting
|
||||||
|
=============================
|
||||||
|
.. currentmodule:: pina.loss.ntk_weighting
|
||||||
|
|
||||||
|
.. automodule:: pina.loss.ntk_weighting
|
||||||
|
|
||||||
|
.. autoclass:: NeuralTangentKernelWeighting
|
||||||
|
:members:
|
||||||
|
:show-inheritance:
|
||||||
@@ -6,6 +6,7 @@ __all__ = [
|
|||||||
"PowerLoss",
|
"PowerLoss",
|
||||||
"WeightingInterface",
|
"WeightingInterface",
|
||||||
"ScalarWeighting",
|
"ScalarWeighting",
|
||||||
|
"NeuralTangentKernelWeighting",
|
||||||
]
|
]
|
||||||
|
|
||||||
from .loss_interface import LossInterface
|
from .loss_interface import LossInterface
|
||||||
@@ -13,3 +14,4 @@ from .power_loss import PowerLoss
|
|||||||
from .lp_loss import LpLoss
|
from .lp_loss import LpLoss
|
||||||
from .weighting_interface import WeightingInterface
|
from .weighting_interface import WeightingInterface
|
||||||
from .scalar_weighting import ScalarWeighting
|
from .scalar_weighting import ScalarWeighting
|
||||||
|
from .ntk_weighting import NeuralTangentKernelWeighting
|
||||||
|
|||||||
71
pina/loss/ntk_weighting.py
Normal file
71
pina/loss/ntk_weighting.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""Module for Neural Tangent Kernel Class"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import Module
|
||||||
|
from .weighting_interface import WeightingInterface
|
||||||
|
from ..utils import check_consistency
|
||||||
|
|
||||||
|
|
||||||
|
class NeuralTangentKernelWeighting(WeightingInterface):
|
||||||
|
"""
|
||||||
|
A neural tangent kernel scheme for weighting different losses to
|
||||||
|
boost the convergence.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
**Original reference**: Wang, Sifan, Xinling Yu, and
|
||||||
|
Paris Perdikaris. *When and why PINNs fail to train:
|
||||||
|
A neural tangent kernel perspective*. Journal of
|
||||||
|
Computational Physics 449 (2022): 110768.
|
||||||
|
DOI: `10.1016/j.jcp.2021.110768 <https://doi.org/10.1016/j.jcp.2021.110768>`_.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model, alpha=0.5):
|
||||||
|
"""
|
||||||
|
Initialization of the :class:`NeuralTangentKernelWeighting` class.
|
||||||
|
|
||||||
|
:param torch.nn.Module model: The neural network model.
|
||||||
|
:param float alpha: The alpha parameter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
check_consistency(alpha, float)
|
||||||
|
check_consistency(model, Module)
|
||||||
|
if alpha < 0 or alpha > 1:
|
||||||
|
raise ValueError("alpha should be a value between 0 and 1")
|
||||||
|
self.alpha = alpha
|
||||||
|
self.model = model
|
||||||
|
self.weights = {}
|
||||||
|
self.default_value_weights = 1
|
||||||
|
|
||||||
|
def aggregate(self, losses):
|
||||||
|
"""
|
||||||
|
Weights the losses according to the Neural Tangent Kernel
|
||||||
|
algorithm.
|
||||||
|
|
||||||
|
:param dict(torch.Tensor) input: The dictionary of losses.
|
||||||
|
:return: The losses aggregation. It should be a scalar Tensor.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
losses_norm = {}
|
||||||
|
for condition in losses:
|
||||||
|
losses[condition].backward(retain_graph=True)
|
||||||
|
grads = []
|
||||||
|
for param in self.model.parameters():
|
||||||
|
grads.append(param.grad.view(-1))
|
||||||
|
grads = torch.cat(grads)
|
||||||
|
losses_norm[condition] = torch.norm(grads)
|
||||||
|
self.weights = {
|
||||||
|
condition: self.alpha
|
||||||
|
* self.weights.get(condition, self.default_value_weights)
|
||||||
|
+ (1 - self.alpha)
|
||||||
|
* losses_norm[condition]
|
||||||
|
/ sum(losses_norm.values())
|
||||||
|
for condition in losses
|
||||||
|
}
|
||||||
|
return sum(
|
||||||
|
self.weights[condition] * loss for condition, loss in losses.items()
|
||||||
|
)
|
||||||
65
tests/test_weighting/test_ntk_weighting.py
Normal file
65
tests/test_weighting/test_ntk_weighting.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import pytest
|
||||||
|
from pina import Trainer
|
||||||
|
from pina.solver import PINN
|
||||||
|
from pina.model import FeedForward
|
||||||
|
from pina.problem.zoo import Poisson2DSquareProblem
|
||||||
|
from pina.loss import NeuralTangentKernelWeighting
|
||||||
|
|
||||||
|
problem = Poisson2DSquareProblem()
|
||||||
|
condition_names = problem.conditions.keys()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model,alpha",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
FeedForward(
|
||||||
|
len(problem.input_variables), len(problem.output_variables)
|
||||||
|
),
|
||||||
|
0.5,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_constructor(model, alpha):
|
||||||
|
NeuralTangentKernelWeighting(model=model, alpha=alpha)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", [0.5])
|
||||||
|
def test_wrong_constructor1(model):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
NeuralTangentKernelWeighting(model)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model,alpha",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
FeedForward(
|
||||||
|
len(problem.input_variables), len(problem.output_variables)
|
||||||
|
),
|
||||||
|
1.2,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_wrong_constructor2(model, alpha):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
NeuralTangentKernelWeighting(model, alpha)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model,alpha",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
FeedForward(
|
||||||
|
len(problem.input_variables), len(problem.output_variables)
|
||||||
|
),
|
||||||
|
0.5,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_train_aggregation(model, alpha):
|
||||||
|
weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
|
||||||
|
problem.discretise_domain(50)
|
||||||
|
solver = PINN(problem=problem, model=model, weighting=weighting)
|
||||||
|
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
|
||||||
|
trainer.train()
|
||||||
Reference in New Issue
Block a user