add self-adaptive weighting
committed by Giovanni Canali
parent bacd7e202a
commit c42bdd575c

@@ -267,3 +267,4 @@ Losses and Weightings
 WeightingInterface <loss/weighting_interface.rst>
 ScalarWeighting <loss/scalar_weighting.rst>
 NeuralTangentKernelWeighting <loss/ntk_weighting.rst>
+SelfAdaptiveWeighting <loss/self_adaptive_weighting.rst>

docs/source/_rst/loss/self_adaptive_weighting.rst (new file, 9 lines)
@@ -0,0 +1,9 @@
+SelfAdaptiveWeighting
+=============================
+.. currentmodule:: pina.loss.self_adaptive_weighting
+
+.. automodule:: pina.loss.self_adaptive_weighting
+
+.. autoclass:: SelfAdaptiveWeighting
+    :members:
+    :show-inheritance:

pina/loss/__init__.py
@@ -7,6 +7,7 @@ __all__ = [
     "WeightingInterface",
     "ScalarWeighting",
     "NeuralTangentKernelWeighting",
+    "SelfAdaptiveWeighting",
 ]

 from .loss_interface import LossInterface
@@ -15,3 +16,4 @@ from .lp_loss import LpLoss
 from .weighting_interface import WeightingInterface
 from .scalar_weighting import ScalarWeighting
 from .ntk_weighting import NeuralTangentKernelWeighting
+from .self_adaptive_weighting import SelfAdaptiveWeighting
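
With the export above in place, the scheme can be imported directly from pina.loss, as the new test file below does. A minimal construction sketch; k=50 is only an illustrative value (the class default is 100):

from pina.loss import SelfAdaptiveWeighting

# Recompute the condition weights every 50 epochs (illustrative; default is k=100)
weighting = SelfAdaptiveWeighting(k=50)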

pina/loss/self_adaptive_weighting.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+"""Module for Self-Adaptive Weighting class."""
+
+import torch
+from .weighting_interface import WeightingInterface
+from ..utils import check_positive_integer
+
+
+class SelfAdaptiveWeighting(WeightingInterface):
+    """
+    A self-adaptive weighting scheme to tackle the imbalance among the loss
+    components. This formulation equalizes the gradient norms of the losses,
+    preventing bias toward any particular term during training.
+
+    .. seealso::
+
+        **Original reference**:
+        Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
+        *Simulating Three-dimensional Turbulence with Physics-informed Neural
+        Networks*.
+        ArXiv preprint: `arXiv:2507.08972
+        <https://arxiv.org/abs/2507.08972>`_
+
+    """
+
+    def __init__(self, k=100):
+        """
+        Initialization of the :class:`SelfAdaptiveWeighting` class.
+
+        :param int k: The number of epochs after which the weights are updated.
+            Default is 100.
+
+        :raises AssertionError: If ``k`` is not a positive integer.
+        """
+        super().__init__()
+
+        # Check consistency
+        check_positive_integer(value=k, strict=True)
+
+        # Initialize parameters
+        self.k = k
+        self.weights = {}
+        self.default_value_weights = 1.0
+
+    def aggregate(self, losses):
+        """
+        Weight the losses according to the self-adaptive algorithm.
+
+        :param dict(torch.Tensor) losses: The dictionary of losses.
+        :return: The aggregation of the losses. It should be a scalar Tensor.
+        :rtype: torch.Tensor
+        """
+        # If weights have not been initialized, set them to 1
+        if not self.weights:
+            self.weights = {
+                condition: self.default_value_weights for condition in losses
+            }
+
+        # Update every k epochs
+        if self.solver.trainer.current_epoch % self.k == 0:
+
+            # Define a dictionary to store the norms of the gradients
+            losses_norm = {}
+
+            # Compute the gradient norms for each loss component
+            for condition, loss in losses.items():
+                loss.backward(retain_graph=True)
+                grads = torch.cat(
+                    [p.grad.flatten() for p in self.solver.model.parameters()]
+                )
+                losses_norm[condition] = grads.norm()
+
+            # Update the weights
+            self.weights = {
+                condition: sum(losses_norm.values()) / losses_norm[condition]
+                for condition in losses
+            }
+
+        return sum(
+            self.weights[condition] * loss for condition, loss in losses.items()
+        )
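
In formula form, the update implemented above sets, every k epochs, w_c = (sum over all conditions c' of ||grad_theta L_c'||) / ||grad_theta L_c||, and the aggregated loss is sum_c w_c * L_c, so each weighted component contributes a gradient of the same magnitude. The short standalone sketch below reproduces that rule on a toy two-term loss; the Linear model, the condition names, and the loss scales are invented for illustration only, and it isolates per-loss gradients with torch.autograd.grad rather than the backward/p.grad bookkeeping the class performs through the PINA solver and trainer.

import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
x = torch.randn(8, 2)

# Two toy loss components with deliberately different scales
losses = {
    "physics": (model(x) ** 2).mean(),
    "boundary": 1e3 * (model(x) - 1.0).pow(2).mean(),
}

# Per-component gradient norms over all model parameters
norms = {}
for name, loss in losses.items():
    grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True)
    norms[name] = torch.cat([g.flatten() for g in grads]).norm()

# Update rule: w_c = (sum of all norms) / norm_c, so w_c * norm_c is equal for every c
weights = {name: sum(norms.values()) / norms[name] for name in losses}
total = sum(weights[name] * loss for name, loss in losses.items())
print({name: float(w) for name, w in weights.items()}, float(total))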

tests/test_weighting/test_self_adaptive_weighting.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+import pytest
+from pina import Trainer
+from pina.solver import PINN
+from pina.model import FeedForward
+from pina.loss import SelfAdaptiveWeighting
+from pina.problem.zoo import Poisson2DSquareProblem
+
+
+# Initialize problem and model
+problem = Poisson2DSquareProblem()
+problem.discretise_domain(10)
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
+
+@pytest.mark.parametrize("k", [10, 100, 1000])
+def test_constructor(k):
+    SelfAdaptiveWeighting(k=k)
+
+    # Should fail if k is not an integer
+    with pytest.raises(AssertionError):
+        SelfAdaptiveWeighting(k=1.5)
+
+    # Should fail if k is not > 0
+    with pytest.raises(AssertionError):
+        SelfAdaptiveWeighting(k=0)
+
+    # Should fail if k is negative
+    with pytest.raises(AssertionError):
+        SelfAdaptiveWeighting(k=-3)
+
+
+@pytest.mark.parametrize("k", [2, 3])
+def test_train_aggregation(k):
+    weighting = SelfAdaptiveWeighting(k=k)
+    solver = PINN(problem=problem, model=model, weighting=weighting)
+    trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
+    trainer.train()