add mutual solver-weighting link

This commit is contained in:
giovanni
2025-08-29 19:11:08 +02:00
committed by Giovanni Canali
parent 973d0c05ab
commit bacd7e202a
6 changed files with 62 additions and 76 deletions

View File

@@ -1,7 +1,6 @@
"""Module for Neural Tangent Kernel Class""" """Module for Neural Tangent Kernel Class"""
import torch import torch
from torch.nn import Module
from .weighting_interface import WeightingInterface from .weighting_interface import WeightingInterface
from ..utils import check_consistency from ..utils import check_consistency
@@ -21,43 +20,45 @@ class NeuralTangentKernelWeighting(WeightingInterface):
""" """
def __init__(self, model, alpha=0.5): def __init__(self, alpha=0.5):
""" """
Initialization of the :class:`NeuralTangentKernelWeighting` class. Initialization of the :class:`NeuralTangentKernelWeighting` class.
:param torch.nn.Module model: The neural network model.
:param float alpha: The alpha parameter. :param float alpha: The alpha parameter.
:raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive). :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
""" """
super().__init__() super().__init__()
# Check consistency
check_consistency(alpha, float) check_consistency(alpha, float)
check_consistency(model, Module)
if alpha < 0 or alpha > 1: if alpha < 0 or alpha > 1:
raise ValueError("alpha should be a value between 0 and 1") raise ValueError("alpha should be a value between 0 and 1")
# Initialize parameters
self.alpha = alpha self.alpha = alpha
self.model = model
self.weights = {} self.weights = {}
self.default_value_weights = 1 self.default_value_weights = 1.0
def aggregate(self, losses): def aggregate(self, losses):
""" """
Weight the losses according to the Neural Tangent Kernel Weight the losses according to the Neural Tangent Kernel algorithm.
algorithm.
:param dict(torch.Tensor) input: The dictionary of losses. :param dict(torch.Tensor) input: The dictionary of losses.
:return: The losses aggregation. It should be a scalar Tensor. :return: The aggregation of the losses. It should be a scalar Tensor.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
# Define a dictionary to store the norms of the gradients
losses_norm = {} losses_norm = {}
for condition in losses:
losses[condition].backward(retain_graph=True) # Compute the gradient norms for each loss component
grads = [] for condition, loss in losses.items():
for param in self.model.parameters(): loss.backward(retain_graph=True)
grads.append(param.grad.view(-1)) grads = torch.cat(
grads = torch.cat(grads) [p.grad.flatten() for p in self.solver.model.parameters()]
losses_norm[condition] = torch.norm(grads) )
losses_norm[condition] = grads.norm()
# Update the weights
self.weights = { self.weights = {
condition: self.alpha condition: self.alpha
* self.weights.get(condition, self.default_value_weights) * self.weights.get(condition, self.default_value_weights)
@@ -66,6 +67,7 @@ class NeuralTangentKernelWeighting(WeightingInterface):
/ sum(losses_norm.values()) / sum(losses_norm.values())
for condition in losses for condition in losses
} }
return sum( return sum(
self.weights[condition] * loss for condition, loss in losses.items() self.weights[condition] * loss for condition, loss in losses.items()
) )

View File

@@ -37,12 +37,16 @@ class ScalarWeighting(WeightingInterface):
:type weights: float | int | dict :type weights: float | int | dict
""" """
super().__init__() super().__init__()
# Check consistency
check_consistency([weights], (float, dict, int)) check_consistency([weights], (float, dict, int))
# Weights initialization
if isinstance(weights, (float, int)): if isinstance(weights, (float, int)):
self.default_value_weights = weights self.default_value_weights = weights
self.weights = {} self.weights = {}
else: else:
self.default_value_weights = 1 self.default_value_weights = 1.0
self.weights = weights self.weights = weights
def aggregate(self, losses): def aggregate(self, losses):

View File

@@ -13,7 +13,7 @@ class WeightingInterface(metaclass=ABCMeta):
""" """
Initialization of the :class:`WeightingInterface` class. Initialization of the :class:`WeightingInterface` class.
""" """
self.condition_names = None self._solver = None
@abstractmethod @abstractmethod
def aggregate(self, losses): def aggregate(self, losses):
@@ -22,3 +22,13 @@ class WeightingInterface(metaclass=ABCMeta):
:param dict losses: The dictionary of losses. :param dict losses: The dictionary of losses.
""" """
@property
def solver(self):
"""
The solver employing this weighting schema.
:return: The solver.
:rtype: :class:`~pina.solver.SolverInterface`
"""
return self._solver

View File

@@ -44,7 +44,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
weighting = _NoWeighting() weighting = _NoWeighting()
check_consistency(weighting, WeightingInterface) check_consistency(weighting, WeightingInterface)
self._pina_weighting = weighting self._pina_weighting = weighting
weighting.condition_names = list(self._pina_problem.conditions.keys()) weighting._solver = self
# check consistency use_lt # check consistency use_lt
check_consistency(use_lt, bool) check_consistency(use_lt, bool)

View File

@@ -2,64 +2,32 @@ import pytest
from pina import Trainer from pina import Trainer
from pina.solver import PINN from pina.solver import PINN
from pina.model import FeedForward from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem
from pina.loss import NeuralTangentKernelWeighting from pina.loss import NeuralTangentKernelWeighting
from pina.problem.zoo import Poisson2DSquareProblem
# Initialize problem and model
problem = Poisson2DSquareProblem() problem = Poisson2DSquareProblem()
condition_names = problem.conditions.keys() problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
@pytest.mark.parametrize( @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
"model,alpha", def test_constructor(alpha):
[ NeuralTangentKernelWeighting(alpha=alpha)
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
0.5,
)
],
)
def test_constructor(model, alpha):
NeuralTangentKernelWeighting(model=model, alpha=alpha)
# Should fail if alpha is not >= 0
@pytest.mark.parametrize("model", [0.5])
def test_wrong_constructor1(model):
with pytest.raises(ValueError): with pytest.raises(ValueError):
NeuralTangentKernelWeighting(model) NeuralTangentKernelWeighting(alpha=-0.1)
# Should fail if alpha is not <= 1
@pytest.mark.parametrize(
"model,alpha",
[
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
1.2,
)
],
)
def test_wrong_constructor2(model, alpha):
with pytest.raises(ValueError): with pytest.raises(ValueError):
NeuralTangentKernelWeighting(model, alpha) NeuralTangentKernelWeighting(alpha=1.1)
@pytest.mark.parametrize( @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
"model,alpha", def test_train_aggregation(alpha):
[ weighting = NeuralTangentKernelWeighting(alpha=alpha)
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
0.5,
)
],
)
def test_train_aggregation(model, alpha):
weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
problem.discretise_domain(50)
solver = PINN(problem=problem, model=model, weighting=weighting) solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train() trainer.train()

View File

@@ -1,16 +1,17 @@
import pytest import pytest
import torch import torch
from pina import Trainer from pina import Trainer
from pina.solver import PINN from pina.solver import PINN
from pina.model import FeedForward from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem
from pina.loss import ScalarWeighting from pina.loss import ScalarWeighting
from pina.problem.zoo import Poisson2DSquareProblem
# Initialize problem and model
problem = Poisson2DSquareProblem() problem = Poisson2DSquareProblem()
problem.discretise_domain(50)
model = FeedForward(len(problem.input_variables), len(problem.output_variables)) model = FeedForward(len(problem.input_variables), len(problem.output_variables))
condition_names = problem.conditions.keys() condition_names = problem.conditions.keys()
print(problem.conditions.keys())
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -19,11 +20,13 @@ print(problem.conditions.keys())
def test_constructor(weights): def test_constructor(weights):
ScalarWeighting(weights=weights) ScalarWeighting(weights=weights)
# Should fail if weights are not a scalar
@pytest.mark.parametrize("weights", ["a", [1, 2, 3]])
def test_wrong_constructor(weights):
with pytest.raises(ValueError): with pytest.raises(ValueError):
ScalarWeighting(weights=weights) ScalarWeighting(weights="invalid")
# Should fail if weights are not a dictionary
with pytest.raises(ValueError):
ScalarWeighting(weights=[1, 2, 3])
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -45,7 +48,6 @@ def test_aggregate(weights):
) )
def test_train_aggregation(weights): def test_train_aggregation(weights):
weighting = ScalarWeighting(weights=weights) weighting = ScalarWeighting(weights=weights)
problem.discretise_domain(50)
solver = PINN(problem=problem, model=model, weighting=weighting) solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train() trainer.train()