add linear weighting
This commit is contained in:
committed by
Giovanni Canali
parent
96402baf20
commit
ef3542486c
@@ -253,7 +253,6 @@ Callbacks
|
|||||||
Optimizer callback <callback/optimizer_callback.rst>
|
Optimizer callback <callback/optimizer_callback.rst>
|
||||||
R3 Refinment callback <callback/refinement/r3_refinement.rst>
|
R3 Refinment callback <callback/refinement/r3_refinement.rst>
|
||||||
Refinment Interface callback <callback/refinement/refinement_interface.rst>
|
Refinment Interface callback <callback/refinement/refinement_interface.rst>
|
||||||
Weighting callback <callback/linear_weight_update_callback.rst>
|
|
||||||
|
|
||||||
Losses and Weightings
|
Losses and Weightings
|
||||||
---------------------
|
---------------------
|
||||||
@@ -267,4 +266,5 @@ Losses and Weightings
|
|||||||
WeightingInterface <loss/weighting_interface.rst>
|
WeightingInterface <loss/weighting_interface.rst>
|
||||||
ScalarWeighting <loss/scalar_weighting.rst>
|
ScalarWeighting <loss/scalar_weighting.rst>
|
||||||
NeuralTangentKernelWeighting <loss/ntk_weighting.rst>
|
NeuralTangentKernelWeighting <loss/ntk_weighting.rst>
|
||||||
SelfAdaptiveWeighting <loss/self_adaptive_weighting.rst>
|
SelfAdaptiveWeighting <loss/self_adaptive_weighting.rst>
|
||||||
|
LinearWeighting <loss/linear_weighting.rst>
|
||||||
9
docs/source/_rst/loss/linear_weighting.rst
Normal file
9
docs/source/_rst/loss/linear_weighting.rst
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
LinearWeighting
|
||||||
|
=============================
|
||||||
|
.. currentmodule:: pina.loss.linear_weighting
|
||||||
|
|
||||||
|
.. automodule:: pina.loss.linear_weighting
|
||||||
|
|
||||||
|
.. autoclass:: LinearWeighting
|
||||||
|
:members:
|
||||||
|
:show-inheritance:
|
||||||
@@ -8,6 +8,7 @@ __all__ = [
|
|||||||
"ScalarWeighting",
|
"ScalarWeighting",
|
||||||
"NeuralTangentKernelWeighting",
|
"NeuralTangentKernelWeighting",
|
||||||
"SelfAdaptiveWeighting",
|
"SelfAdaptiveWeighting",
|
||||||
|
"LinearWeighting",
|
||||||
]
|
]
|
||||||
|
|
||||||
from .loss_interface import LossInterface
|
from .loss_interface import LossInterface
|
||||||
@@ -17,3 +18,4 @@ from .weighting_interface import WeightingInterface
|
|||||||
from .scalar_weighting import ScalarWeighting
|
from .scalar_weighting import ScalarWeighting
|
||||||
from .ntk_weighting import NeuralTangentKernelWeighting
|
from .ntk_weighting import NeuralTangentKernelWeighting
|
||||||
from .self_adaptive_weighting import SelfAdaptiveWeighting
|
from .self_adaptive_weighting import SelfAdaptiveWeighting
|
||||||
|
from .linear_weighting import LinearWeighting
|
||||||
|
|||||||
64
pina/loss/linear_weighting.py
Normal file
64
pina/loss/linear_weighting.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""Module for the LinearWeighting class."""
|
||||||
|
|
||||||
|
from ..loss import WeightingInterface
|
||||||
|
from ..utils import check_consistency, check_positive_integer
|
||||||
|
|
||||||
|
|
||||||
|
class LinearWeighting(WeightingInterface):
|
||||||
|
"""
|
||||||
|
A weighting scheme that linearly scales weights from initial values to final
|
||||||
|
values over a specified number of epochs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, initial_weights, final_weights, target_epoch):
|
||||||
|
"""
|
||||||
|
:param dict initial_weights: The weights to be assigned to each loss
|
||||||
|
term at the beginning of training. The keys are the conditions and
|
||||||
|
the values are the corresponding weights. If a condition is not
|
||||||
|
present in the dictionary, the default value (1) is used.
|
||||||
|
:param dict final_weights: The weights to be assigned to each loss term
|
||||||
|
once the target epoch is reached. The keys are the conditions and
|
||||||
|
the values are the corresponding weights. If a condition is not
|
||||||
|
present in the dictionary, the default value (1) is used.
|
||||||
|
:param int target_epoch: The epoch at which the weights reach their
|
||||||
|
final values.
|
||||||
|
:raises ValueError: If the keys of the two dictionaries are not
|
||||||
|
consistent.
|
||||||
|
"""
|
||||||
|
super().__init__(update_every_n_epochs=1, aggregator="sum")
|
||||||
|
|
||||||
|
# Check consistency
|
||||||
|
check_consistency([initial_weights, final_weights], dict)
|
||||||
|
check_positive_integer(value=target_epoch, strict=True)
|
||||||
|
|
||||||
|
# Check that the keys of the two dictionaries are the same
|
||||||
|
if initial_weights.keys() != final_weights.keys():
|
||||||
|
raise ValueError(
|
||||||
|
"The keys of the initial_weights and final_weights "
|
||||||
|
"dictionaries must be the same."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialization
|
||||||
|
self.initial_weights = initial_weights
|
||||||
|
self.final_weights = final_weights
|
||||||
|
self.target_epoch = target_epoch
|
||||||
|
|
||||||
|
def weights_update(self, losses):
|
||||||
|
"""
|
||||||
|
Update the weighting scheme based on the given losses.
|
||||||
|
|
||||||
|
:param dict losses: The dictionary of losses.
|
||||||
|
:return: The updated weights.
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
condition: self.last_saved_weights().get(
|
||||||
|
condition, self.initial_weights.get(condition, 1)
|
||||||
|
)
|
||||||
|
+ (
|
||||||
|
self.final_weights.get(condition, 1)
|
||||||
|
- self.initial_weights.get(condition, 1)
|
||||||
|
)
|
||||||
|
/ (self.target_epoch)
|
||||||
|
for condition in losses.keys()
|
||||||
|
}
|
||||||
@@ -61,11 +61,10 @@ class NeuralTangentKernelWeighting(WeightingInterface):
|
|||||||
losses_norm[condition] = grads.norm()
|
losses_norm[condition] = grads.norm()
|
||||||
|
|
||||||
# Update the weights
|
# Update the weights
|
||||||
self.weights = {
|
return {
|
||||||
condition: self.alpha * self.weights.get(condition, 1)
|
condition: self.alpha * self.last_saved_weights().get(condition, 1)
|
||||||
+ (1 - self.alpha)
|
+ (1 - self.alpha)
|
||||||
* losses_norm[condition]
|
* losses_norm[condition]
|
||||||
/ sum(losses_norm.values())
|
/ sum(losses_norm.values())
|
||||||
for condition in losses
|
for condition in losses
|
||||||
}
|
}
|
||||||
return self.weights
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class ScalarWeighting(WeightingInterface):
|
|||||||
If a single scalar value is provided, it is assigned to all loss
|
If a single scalar value is provided, it is assigned to all loss
|
||||||
terms. If a dictionary is provided, the keys are the conditions and
|
terms. If a dictionary is provided, the keys are the conditions and
|
||||||
the values are the weights. If a condition is not present in the
|
the values are the weights. If a condition is not present in the
|
||||||
dictionary, the default value is used.
|
dictionary, the default value (1) is used.
|
||||||
:type weights: float | int | dict
|
:type weights: float | int | dict
|
||||||
"""
|
"""
|
||||||
super().__init__(update_every_n_epochs=1, aggregator="sum")
|
super().__init__(update_every_n_epochs=1, aggregator="sum")
|
||||||
@@ -29,11 +29,9 @@ class ScalarWeighting(WeightingInterface):
|
|||||||
if isinstance(weights, dict):
|
if isinstance(weights, dict):
|
||||||
self.values = weights
|
self.values = weights
|
||||||
self.default_value_weights = 1
|
self.default_value_weights = 1
|
||||||
elif isinstance(weights, (float, int)):
|
else:
|
||||||
self.values = {}
|
self.values = {}
|
||||||
self.default_value_weights = weights
|
self.default_value_weights = weights
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
def weights_update(self, losses):
|
def weights_update(self, losses):
|
||||||
"""
|
"""
|
||||||
|
|||||||
95
tests/test_weighting/test_linear_weighting.py
Normal file
95
tests/test_weighting/test_linear_weighting.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import math
|
||||||
|
import pytest
|
||||||
|
from pina import Trainer
|
||||||
|
from pina.solver import PINN
|
||||||
|
from pina.model import FeedForward
|
||||||
|
from pina.loss import LinearWeighting
|
||||||
|
from pina.problem.zoo import Poisson2DSquareProblem
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize problem and model
|
||||||
|
problem = Poisson2DSquareProblem()
|
||||||
|
problem.discretise_domain(10)
|
||||||
|
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
|
||||||
|
|
||||||
|
# Weights for testing
|
||||||
|
init_weight_1 = {cond: 3 for cond in problem.conditions.keys()}
|
||||||
|
init_weight_2 = {cond: 4 for cond in problem.conditions.keys()}
|
||||||
|
final_weight_1 = {cond: 1 for cond in problem.conditions.keys()}
|
||||||
|
final_weight_2 = {cond: 5 for cond in problem.conditions.keys()}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2])
|
||||||
|
@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2])
|
||||||
|
@pytest.mark.parametrize("target_epoch", [5, 10])
|
||||||
|
def test_constructor(initial_weights, final_weights, target_epoch):
|
||||||
|
LinearWeighting(
|
||||||
|
initial_weights=initial_weights,
|
||||||
|
final_weights=final_weights,
|
||||||
|
target_epoch=target_epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if initial_weights is not a dictionary
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
LinearWeighting(
|
||||||
|
initial_weights=[1, 1, 1],
|
||||||
|
final_weights=final_weights,
|
||||||
|
target_epoch=target_epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if final_weights is not a dictionary
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
LinearWeighting(
|
||||||
|
initial_weights=initial_weights,
|
||||||
|
final_weights=[1, 1, 1],
|
||||||
|
target_epoch=target_epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if target_epoch is not an integer
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
LinearWeighting(
|
||||||
|
initial_weights=initial_weights,
|
||||||
|
final_weights=final_weights,
|
||||||
|
target_epoch=1.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if target_epoch is not positive
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
LinearWeighting(
|
||||||
|
initial_weights=initial_weights,
|
||||||
|
final_weights=final_weights,
|
||||||
|
target_epoch=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail if dictionary keys do not match
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
LinearWeighting(
|
||||||
|
initial_weights={list(initial_weights.keys())[0]: 1},
|
||||||
|
final_weights=final_weights,
|
||||||
|
target_epoch=target_epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2])
|
||||||
|
@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2])
|
||||||
|
@pytest.mark.parametrize("target_epoch", [5, 10])
|
||||||
|
def test_train_aggregation(initial_weights, final_weights, target_epoch):
|
||||||
|
weighting = LinearWeighting(
|
||||||
|
initial_weights=initial_weights,
|
||||||
|
final_weights=final_weights,
|
||||||
|
target_epoch=target_epoch,
|
||||||
|
)
|
||||||
|
solver = PINN(problem=problem, model=model, weighting=weighting)
|
||||||
|
trainer = Trainer(solver=solver, max_epochs=target_epoch, accelerator="cpu")
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that weights are updated correctly
|
||||||
|
assert all(
|
||||||
|
math.isclose(
|
||||||
|
weighting.last_saved_weights()[cond],
|
||||||
|
final_weights[cond],
|
||||||
|
rel_tol=1e-5,
|
||||||
|
abs_tol=1e-8,
|
||||||
|
)
|
||||||
|
for cond in final_weights.keys()
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user