add mutual solver-weighting link
Commit bacd7e202a (parent 973d0c05ab), committed by Giovanni Canali
@@ -1,7 +1,6 @@
 """Module for Neural Tangent Kernel Class"""
 
 import torch
-from torch.nn import Module
 from .weighting_interface import WeightingInterface
 from ..utils import check_consistency
 
@@ -21,43 +20,45 @@ class NeuralTangentKernelWeighting(WeightingInterface):
     """
 
-    def __init__(self, model, alpha=0.5):
+    def __init__(self, alpha=0.5):
         """
         Initialization of the :class:`NeuralTangentKernelWeighting` class.
 
-        :param torch.nn.Module model: The neural network model.
         :param float alpha: The alpha parameter.
 
         :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
         """
         super().__init__()
 
         # Check consistency
         check_consistency(alpha, float)
-        check_consistency(model, Module)
         if alpha < 0 or alpha > 1:
             raise ValueError("alpha should be a value between 0 and 1")
 
         # Initialize parameters
         self.alpha = alpha
-        self.model = model
         self.weights = {}
-        self.default_value_weights = 1
+        self.default_value_weights = 1.0
 
     def aggregate(self, losses):
         """
-        Weight the losses according to the Neural Tangent Kernel
-        algorithm.
+        Weight the losses according to the Neural Tangent Kernel algorithm.
 
         :param dict(torch.Tensor) input: The dictionary of losses.
-        :return: The losses aggregation. It should be a scalar Tensor.
+        :return: The aggregation of the losses. It should be a scalar Tensor.
         :rtype: torch.Tensor
         """
+        # Define a dictionary to store the norms of the gradients
         losses_norm = {}
-        for condition in losses:
-            losses[condition].backward(retain_graph=True)
-            grads = []
-            for param in self.model.parameters():
-                grads.append(param.grad.view(-1))
-            grads = torch.cat(grads)
-            losses_norm[condition] = torch.norm(grads)
+
+        # Compute the gradient norms for each loss component
+        for condition, loss in losses.items():
+            loss.backward(retain_graph=True)
+            grads = torch.cat(
+                [p.grad.flatten() for p in self.solver.model.parameters()]
+            )
+            losses_norm[condition] = grads.norm()
 
         # Update the weights
         self.weights = {
             condition: self.alpha
             * self.weights.get(condition, self.default_value_weights)
@@ -66,6 +67,7 @@ class NeuralTangentKernelWeighting(WeightingInterface):
             / sum(losses_norm.values())
             for condition in losses
         }
+
         return sum(
             self.weights[condition] * loss for condition, loss in losses.items()
         )
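For orientation, here is a standalone sketch of what the rewritten aggregate() does end to end. Everything below is illustrative rather than PINA source: the model, condition names, and losses are made up, the explicit zero_grad() is an assumption added so each norm reflects a single condition, and the two update lines falling between the hunks above are not shown in this diff, so the (1 - alpha) * norm term is inferred from the visible pieces.

import torch

# Hypothetical stand-ins, not part of this commit
model = torch.nn.Linear(2, 1)
x = torch.randn(8, 2)
losses = {
    "pde": model(x).pow(2).mean(),        # made-up residual loss
    "bc": (model(x) - 1.0).abs().mean(),  # made-up boundary loss
}
alpha, weights, default = 0.5, {}, 1.0

# 1) One gradient norm per condition
losses_norm = {}
for condition, loss in losses.items():
    model.zero_grad()  # assumed step: isolate this condition's gradient
    loss.backward(retain_graph=True)
    grads = torch.cat([p.grad.flatten() for p in model.parameters()])
    losses_norm[condition] = grads.norm()

# 2) EMA between previous weights and normalized gradient norms
#    (the middle of this expression is elided in the diff; form is inferred)
total = sum(losses_norm.values())
weights = {
    c: alpha * weights.get(c, default)
    + (1.0 - alpha) * losses_norm[c] / total
    for c in losses
}

# 3) Weighted sum, matching the diff's return statement
aggregated = sum(weights[c] * loss for c, loss in losses.items())

As a numeric check of step 2: with gradient norms 3.0 and 1.0 and both previous weights at the 1.0 default, the updated weights are 0.5 * 1.0 + 0.5 * 3/4 = 0.875 and 0.5 * 1.0 + 0.5 * 1/4 = 0.625, so the condition with the larger gradient norm is pulled upward.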
@@ -37,12 +37,16 @@ class ScalarWeighting(WeightingInterface):
         :type weights: float | int | dict
         """
         super().__init__()
+
+        # Check consistency
         check_consistency([weights], (float, dict, int))
+
+        # Weights initialization
         if isinstance(weights, (float, int)):
             self.default_value_weights = weights
             self.weights = {}
         else:
-            self.default_value_weights = 1
+            self.default_value_weights = 1.0
             self.weights = weights
 
     def aggregate(self, losses):
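A hedged usage note for the branches above: a float or int weights every condition equally through the default value, while a dict assigns per-condition weights, and conditions missing from the dict presumably fall back to the new 1.0 default. A minimal sketch (the condition names are made up, and the import of ScalarWeighting from the package is omitted since this diff does not show its module path):

uniform = ScalarWeighting(weights=0.5)  # every condition's loss scaled by 0.5
per_condition = ScalarWeighting(weights={"pde": 10.0, "bc": 1.0})
# a condition absent from the dict (say "ic") would get the 1.0 default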
@@ -13,7 +13,7 @@ class WeightingInterface(metaclass=ABCMeta):
         """
         Initialization of the :class:`WeightingInterface` class.
         """
-        self.condition_names = None
+        self._solver = None
 
     @abstractmethod
     def aggregate(self, losses):
@@ -22,3 +22,13 @@ class WeightingInterface(metaclass=ABCMeta):
 
         :param dict losses: The dictionary of losses.
         """
+
+    @property
+    def solver(self):
+        """
+        The solver employing this weighting schema.
+
+        :return: The solver.
+        :rtype: :class:`~pina.solver.SolverInterface`
+        """
+        return self._solver
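The commit title, add mutual solver-weighting link, suggests that the solver side assigns itself to the weighting's _solver attribute, so that schemes like NeuralTangentKernelWeighting can reach the network through self.solver.model. That assignment is not part of this diff, so the wiring below is an assumption, with ToySolver as a made-up stand-in and the package import of NeuralTangentKernelWeighting omitted:

import torch

class ToySolver:
    """Hypothetical solver stand-in; not part of this commit."""

    def __init__(self, model, weighting):
        self.model = model
        self.weighting = weighting
        weighting._solver = self  # back-reference read by the new property

weighting = NeuralTangentKernelWeighting(alpha=0.5)
solver = ToySolver(model=torch.nn.Linear(2, 1), weighting=weighting)
assert weighting.solver is solver              # via the new property
assert weighting.solver.model is solver.model  # what aggregate() now uses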