update doc

This commit is contained in:
Dario Coscia
2025-03-19 17:56:27 +01:00
committed by Nicola Demo
parent 69283e7fde
commit 3a7a4a0950
2 changed files with 5 additions and 4 deletions

View File

@@ -248,3 +248,4 @@ Losses and Weightings
PowerLoss <loss/powerloss.rst>
WeightingInterface <loss/weighting_interface.rst>
ScalarWeighting <loss/scalar_weighting.rst>
NeuralTangentKernelWeighting <loss/ntk_weighting.rst>

View File

@@ -17,9 +17,7 @@ class NeuralTangentKernelWeighting(WeightingInterface):
Paris Perdikaris. *When and why PINNs fail to train:
A neural tangent kernel perspective*. Journal of
Computational Physics 449 (2022): 110768.
DOI: `10.1016/j.jcp.2021.110768 <https://doi.org/10.1016/j.jcp.2021.110768>`_.
DOI: `10.1016 <https://doi.org/10.1016/j.jcp.2021.110768>`_.
"""
@@ -29,6 +27,8 @@ class NeuralTangentKernelWeighting(WeightingInterface):
:param torch.nn.Module model: The neural network model.
:param float alpha: The alpha parameter.
:raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
"""
super().__init__()
@@ -43,7 +43,7 @@ class NeuralTangentKernelWeighting(WeightingInterface):
def aggregate(self, losses):
"""
Weights the losses according to the Neural Tangent Kernel
Weight the losses according to the Neural Tangent Kernel
algorithm.
:param dict(torch.Tensor) input: The dictionary of losses.