add mutual solver-weighting link
This commit is contained in:
committed by
Giovanni Canali
parent
973d0c05ab
commit
bacd7e202a
@@ -1,7 +1,6 @@
|
|||||||
"""Module for Neural Tangent Kernel Class"""
|
"""Module for Neural Tangent Kernel Class"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
|
||||||
from .weighting_interface import WeightingInterface
|
from .weighting_interface import WeightingInterface
|
||||||
from ..utils import check_consistency
|
from ..utils import check_consistency
|
||||||
|
|
||||||
@@ -21,43 +20,45 @@ class NeuralTangentKernelWeighting(WeightingInterface):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, alpha=0.5):
|
def __init__(self, alpha=0.5):
|
||||||
"""
|
"""
|
||||||
Initialization of the :class:`NeuralTangentKernelWeighting` class.
|
Initialization of the :class:`NeuralTangentKernelWeighting` class.
|
||||||
|
|
||||||
:param torch.nn.Module model: The neural network model.
|
|
||||||
:param float alpha: The alpha parameter.
|
:param float alpha: The alpha parameter.
|
||||||
|
|
||||||
:raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
|
:raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# Check consistency
|
||||||
check_consistency(alpha, float)
|
check_consistency(alpha, float)
|
||||||
check_consistency(model, Module)
|
|
||||||
if alpha < 0 or alpha > 1:
|
if alpha < 0 or alpha > 1:
|
||||||
raise ValueError("alpha should be a value between 0 and 1")
|
raise ValueError("alpha should be a value between 0 and 1")
|
||||||
|
|
||||||
|
# Initialize parameters
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.model = model
|
|
||||||
self.weights = {}
|
self.weights = {}
|
||||||
self.default_value_weights = 1
|
self.default_value_weights = 1.0
|
||||||
|
|
||||||
def aggregate(self, losses):
|
def aggregate(self, losses):
|
||||||
"""
|
"""
|
||||||
Weight the losses according to the Neural Tangent Kernel
|
Weight the losses according to the Neural Tangent Kernel algorithm.
|
||||||
algorithm.
|
|
||||||
|
|
||||||
:param dict(torch.Tensor) losses: The dictionary of losses.
|
:param dict(torch.Tensor) losses: The dictionary of losses.
|
||||||
:return: The losses aggregation. It should be a scalar Tensor.
|
:return: The aggregation of the losses. It should be a scalar Tensor.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
# Define a dictionary to store the norms of the gradients
|
||||||
losses_norm = {}
|
losses_norm = {}
|
||||||
for condition in losses:
|
|
||||||
losses[condition].backward(retain_graph=True)
|
# Compute the gradient norms for each loss component
|
||||||
grads = []
|
for condition, loss in losses.items():
|
||||||
for param in self.model.parameters():
|
loss.backward(retain_graph=True)
|
||||||
grads.append(param.grad.view(-1))
|
grads = torch.cat(
|
||||||
grads = torch.cat(grads)
|
[p.grad.flatten() for p in self.solver.model.parameters()]
|
||||||
losses_norm[condition] = torch.norm(grads)
|
)
|
||||||
|
losses_norm[condition] = grads.norm()
|
||||||
|
|
||||||
|
# Update the weights
|
||||||
self.weights = {
|
self.weights = {
|
||||||
condition: self.alpha
|
condition: self.alpha
|
||||||
* self.weights.get(condition, self.default_value_weights)
|
* self.weights.get(condition, self.default_value_weights)
|
||||||
@@ -66,6 +67,7 @@ class NeuralTangentKernelWeighting(WeightingInterface):
|
|||||||
/ sum(losses_norm.values())
|
/ sum(losses_norm.values())
|
||||||
for condition in losses
|
for condition in losses
|
||||||
}
|
}
|
||||||
|
|
||||||
return sum(
|
return sum(
|
||||||
self.weights[condition] * loss for condition, loss in losses.items()
|
self.weights[condition] * loss for condition, loss in losses.items()
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -37,12 +37,16 @@ class ScalarWeighting(WeightingInterface):
|
|||||||
:type weights: float | int | dict
|
:type weights: float | int | dict
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# Check consistency
|
||||||
check_consistency([weights], (float, dict, int))
|
check_consistency([weights], (float, dict, int))
|
||||||
|
|
||||||
|
# Weights initialization
|
||||||
if isinstance(weights, (float, int)):
|
if isinstance(weights, (float, int)):
|
||||||
self.default_value_weights = weights
|
self.default_value_weights = weights
|
||||||
self.weights = {}
|
self.weights = {}
|
||||||
else:
|
else:
|
||||||
self.default_value_weights = 1
|
self.default_value_weights = 1.0
|
||||||
self.weights = weights
|
self.weights = weights
|
||||||
|
|
||||||
def aggregate(self, losses):
|
def aggregate(self, losses):
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class WeightingInterface(metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
Initialization of the :class:`WeightingInterface` class.
|
Initialization of the :class:`WeightingInterface` class.
|
||||||
"""
|
"""
|
||||||
self.condition_names = None
|
self._solver = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def aggregate(self, losses):
|
def aggregate(self, losses):
|
||||||
@@ -22,3 +22,13 @@ class WeightingInterface(metaclass=ABCMeta):
|
|||||||
|
|
||||||
:param dict losses: The dictionary of losses.
|
:param dict losses: The dictionary of losses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def solver(self):
|
||||||
|
"""
|
||||||
|
The solver employing this weighting scheme.
|
||||||
|
|
||||||
|
:return: The solver.
|
||||||
|
:rtype: :class:`~pina.solver.SolverInterface`
|
||||||
|
"""
|
||||||
|
return self._solver
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
weighting = _NoWeighting()
|
weighting = _NoWeighting()
|
||||||
check_consistency(weighting, WeightingInterface)
|
check_consistency(weighting, WeightingInterface)
|
||||||
self._pina_weighting = weighting
|
self._pina_weighting = weighting
|
||||||
weighting.condition_names = list(self._pina_problem.conditions.keys())
|
weighting._solver = self
|
||||||
|
|
||||||
# check consistency use_lt
|
# check consistency use_lt
|
||||||
check_consistency(use_lt, bool)
|
check_consistency(use_lt, bool)
|
||||||
|
|||||||
@@ -2,64 +2,32 @@ import pytest
|
|||||||
from pina import Trainer
|
from pina import Trainer
|
||||||
from pina.solver import PINN
|
from pina.solver import PINN
|
||||||
from pina.model import FeedForward
|
from pina.model import FeedForward
|
||||||
from pina.problem.zoo import Poisson2DSquareProblem
|
|
||||||
from pina.loss import NeuralTangentKernelWeighting
|
from pina.loss import NeuralTangentKernelWeighting
|
||||||
|
from pina.problem.zoo import Poisson2DSquareProblem
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize problem and model
|
||||||
problem = Poisson2DSquareProblem()
|
problem = Poisson2DSquareProblem()
|
||||||
condition_names = problem.conditions.keys()
|
problem.discretise_domain(10)
|
||||||
|
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
|
||||||
"model,alpha",
|
def test_constructor(alpha):
|
||||||
[
|
NeuralTangentKernelWeighting(alpha=alpha)
|
||||||
(
|
|
||||||
FeedForward(
|
|
||||||
len(problem.input_variables), len(problem.output_variables)
|
|
||||||
),
|
|
||||||
0.5,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_constructor(model, alpha):
|
|
||||||
NeuralTangentKernelWeighting(model=model, alpha=alpha)
|
|
||||||
|
|
||||||
|
# Should fail if alpha is not >= 0
|
||||||
@pytest.mark.parametrize("model", [0.5])
|
|
||||||
def test_wrong_constructor1(model):
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
NeuralTangentKernelWeighting(model)
|
NeuralTangentKernelWeighting(alpha=-0.1)
|
||||||
|
|
||||||
|
# Should fail if alpha is not <= 1
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model,alpha",
|
|
||||||
[
|
|
||||||
(
|
|
||||||
FeedForward(
|
|
||||||
len(problem.input_variables), len(problem.output_variables)
|
|
||||||
),
|
|
||||||
1.2,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_wrong_constructor2(model, alpha):
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
NeuralTangentKernelWeighting(model, alpha)
|
NeuralTangentKernelWeighting(alpha=1.1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
|
||||||
"model,alpha",
|
def test_train_aggregation(alpha):
|
||||||
[
|
weighting = NeuralTangentKernelWeighting(alpha=alpha)
|
||||||
(
|
|
||||||
FeedForward(
|
|
||||||
len(problem.input_variables), len(problem.output_variables)
|
|
||||||
),
|
|
||||||
0.5,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_train_aggregation(model, alpha):
|
|
||||||
weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
|
|
||||||
problem.discretise_domain(50)
|
|
||||||
solver = PINN(problem=problem, model=model, weighting=weighting)
|
solver = PINN(problem=problem, model=model, weighting=weighting)
|
||||||
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
|
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pina import Trainer
|
from pina import Trainer
|
||||||
from pina.solver import PINN
|
from pina.solver import PINN
|
||||||
from pina.model import FeedForward
|
from pina.model import FeedForward
|
||||||
from pina.problem.zoo import Poisson2DSquareProblem
|
|
||||||
from pina.loss import ScalarWeighting
|
from pina.loss import ScalarWeighting
|
||||||
|
from pina.problem.zoo import Poisson2DSquareProblem
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize problem and model
|
||||||
problem = Poisson2DSquareProblem()
|
problem = Poisson2DSquareProblem()
|
||||||
|
problem.discretise_domain(50)
|
||||||
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
|
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
|
||||||
condition_names = problem.conditions.keys()
|
condition_names = problem.conditions.keys()
|
||||||
print(problem.conditions.keys())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -19,11 +20,13 @@ print(problem.conditions.keys())
|
|||||||
def test_constructor(weights):
|
def test_constructor(weights):
|
||||||
ScalarWeighting(weights=weights)
|
ScalarWeighting(weights=weights)
|
||||||
|
|
||||||
|
# Should fail if weights are a string (neither a scalar nor a dictionary)
|
||||||
@pytest.mark.parametrize("weights", ["a", [1, 2, 3]])
|
|
||||||
def test_wrong_constructor(weights):
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
ScalarWeighting(weights=weights)
|
ScalarWeighting(weights="invalid")
|
||||||
|
|
||||||
|
# Should fail if weights are a list (neither a scalar nor a dictionary)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ScalarWeighting(weights=[1, 2, 3])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -45,7 +48,6 @@ def test_aggregate(weights):
|
|||||||
)
|
)
|
||||||
def test_train_aggregation(weights):
|
def test_train_aggregation(weights):
|
||||||
weighting = ScalarWeighting(weights=weights)
|
weighting = ScalarWeighting(weights=weights)
|
||||||
problem.discretise_domain(50)
|
|
||||||
solver = PINN(problem=problem, model=model, weighting=weighting)
|
solver = PINN(problem=problem, model=model, weighting=weighting)
|
||||||
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
|
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
|
||||||
trainer.train()
|
trainer.train()
|
||||||
Reference in New Issue
Block a user