add mutual solver-weighting link

This commit is contained in:
giovanni
2025-08-29 19:11:08 +02:00
committed by Giovanni Canali
parent 973d0c05ab
commit bacd7e202a
6 changed files with 62 additions and 76 deletions

View File

@@ -1,7 +1,6 @@
"""Module for Neural Tangent Kernel Class"""
import torch
from torch.nn import Module
from .weighting_interface import WeightingInterface
from ..utils import check_consistency
@@ -21,43 +20,45 @@ class NeuralTangentKernelWeighting(WeightingInterface):
"""
def __init__(self, model, alpha=0.5):
def __init__(self, alpha=0.5):
"""
Initialization of the :class:`NeuralTangentKernelWeighting` class.
:param torch.nn.Module model: The neural network model.
:param float alpha: The alpha parameter.
:raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
"""
super().__init__()
# Check consistency
check_consistency(alpha, float)
check_consistency(model, Module)
if alpha < 0 or alpha > 1:
raise ValueError("alpha should be a value between 0 and 1")
# Initialize parameters
self.alpha = alpha
self.model = model
self.weights = {}
self.default_value_weights = 1
self.default_value_weights = 1.0
def aggregate(self, losses):
"""
Weight the losses according to the Neural Tangent Kernel
algorithm.
Weight the losses according to the Neural Tangent Kernel algorithm.
:param dict(torch.Tensor) losses: The dictionary of losses.
:return: The losses aggregation. It should be a scalar Tensor.
:return: The aggregation of the losses. It should be a scalar Tensor.
:rtype: torch.Tensor
"""
# Define a dictionary to store the norms of the gradients
losses_norm = {}
for condition in losses:
losses[condition].backward(retain_graph=True)
grads = []
for param in self.model.parameters():
grads.append(param.grad.view(-1))
grads = torch.cat(grads)
losses_norm[condition] = torch.norm(grads)
# Compute the gradient norms for each loss component
for condition, loss in losses.items():
loss.backward(retain_graph=True)
grads = torch.cat(
[p.grad.flatten() for p in self.solver.model.parameters()]
)
losses_norm[condition] = grads.norm()
# Update the weights
self.weights = {
condition: self.alpha
* self.weights.get(condition, self.default_value_weights)
@@ -66,6 +67,7 @@ class NeuralTangentKernelWeighting(WeightingInterface):
/ sum(losses_norm.values())
for condition in losses
}
return sum(
self.weights[condition] * loss for condition, loss in losses.items()
)

View File

@@ -37,12 +37,16 @@ class ScalarWeighting(WeightingInterface):
:type weights: float | int | dict
"""
super().__init__()
# Check consistency
check_consistency([weights], (float, dict, int))
# Weights initialization
if isinstance(weights, (float, int)):
self.default_value_weights = weights
self.weights = {}
else:
self.default_value_weights = 1
self.default_value_weights = 1.0
self.weights = weights
def aggregate(self, losses):

View File

@@ -13,7 +13,7 @@ class WeightingInterface(metaclass=ABCMeta):
"""
Initialization of the :class:`WeightingInterface` class.
"""
self.condition_names = None
self._solver = None
@abstractmethod
def aggregate(self, losses):
@@ -22,3 +22,13 @@ class WeightingInterface(metaclass=ABCMeta):
:param dict losses: The dictionary of losses.
"""
@property
def solver(self):
"""
The solver employing this weighting scheme.
:return: The solver.
:rtype: :class:`~pina.solver.SolverInterface`
"""
return self._solver

View File

@@ -44,7 +44,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
weighting = _NoWeighting()
check_consistency(weighting, WeightingInterface)
self._pina_weighting = weighting
weighting.condition_names = list(self._pina_problem.conditions.keys())
weighting._solver = self
# check consistency use_lt
check_consistency(use_lt, bool)

View File

@@ -2,64 +2,32 @@ import pytest
from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem
from pina.loss import NeuralTangentKernelWeighting
from pina.problem.zoo import Poisson2DSquareProblem
# Initialize problem and model
problem = Poisson2DSquareProblem()
condition_names = problem.conditions.keys()
problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
@pytest.mark.parametrize(
"model,alpha",
[
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
0.5,
)
],
)
def test_constructor(model, alpha):
NeuralTangentKernelWeighting(model=model, alpha=alpha)
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_constructor(alpha):
NeuralTangentKernelWeighting(alpha=alpha)
@pytest.mark.parametrize("model", [0.5])
def test_wrong_constructor1(model):
# Should fail if alpha is not >= 0
with pytest.raises(ValueError):
NeuralTangentKernelWeighting(model)
NeuralTangentKernelWeighting(alpha=-0.1)
@pytest.mark.parametrize(
"model,alpha",
[
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
1.2,
)
],
)
def test_wrong_constructor2(model, alpha):
# Should fail if alpha is not <= 1
with pytest.raises(ValueError):
NeuralTangentKernelWeighting(model, alpha)
NeuralTangentKernelWeighting(alpha=1.1)
@pytest.mark.parametrize(
"model,alpha",
[
(
FeedForward(
len(problem.input_variables), len(problem.output_variables)
),
0.5,
)
],
)
def test_train_aggregation(model, alpha):
weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
problem.discretise_domain(50)
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_train_aggregation(alpha):
weighting = NeuralTangentKernelWeighting(alpha=alpha)
solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train()

View File

@@ -1,16 +1,17 @@
import pytest
import torch
from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem
from pina.loss import ScalarWeighting
from pina.problem.zoo import Poisson2DSquareProblem
# Initialize problem and model
problem = Poisson2DSquareProblem()
problem.discretise_domain(50)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
condition_names = problem.conditions.keys()
print(problem.conditions.keys())
@pytest.mark.parametrize(
@@ -19,11 +20,13 @@ print(problem.conditions.keys())
def test_constructor(weights):
ScalarWeighting(weights=weights)
@pytest.mark.parametrize("weights", ["a", [1, 2, 3]])
def test_wrong_constructor(weights):
# Should fail if weights are not a scalar
with pytest.raises(ValueError):
ScalarWeighting(weights=weights)
ScalarWeighting(weights="invalid")
# Should fail if weights are not a dictionary
with pytest.raises(ValueError):
ScalarWeighting(weights=[1, 2, 3])
@pytest.mark.parametrize(
@@ -45,7 +48,6 @@ def test_aggregate(weights):
)
def test_train_aggregation(weights):
weighting = ScalarWeighting(weights=weights)
problem.discretise_domain(50)
solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train()