add linear weighting

This commit is contained in:
giovanni
2025-09-05 10:12:09 +02:00
committed by Giovanni Canali
parent 96402baf20
commit ef3542486c
7 changed files with 176 additions and 9 deletions

View File

@@ -253,7 +253,6 @@ Callbacks
Optimizer callback <callback/optimizer_callback.rst> Optimizer callback <callback/optimizer_callback.rst>
R3 Refinment callback <callback/refinement/r3_refinement.rst> R3 Refinment callback <callback/refinement/r3_refinement.rst>
Refinment Interface callback <callback/refinement/refinement_interface.rst> Refinment Interface callback <callback/refinement/refinement_interface.rst>
Weighting callback <callback/linear_weight_update_callback.rst>
Losses and Weightings Losses and Weightings
--------------------- ---------------------
@@ -267,4 +266,5 @@ Losses and Weightings
WeightingInterface <loss/weighting_interface.rst> WeightingInterface <loss/weighting_interface.rst>
ScalarWeighting <loss/scalar_weighting.rst> ScalarWeighting <loss/scalar_weighting.rst>
NeuralTangentKernelWeighting <loss/ntk_weighting.rst> NeuralTangentKernelWeighting <loss/ntk_weighting.rst>
SelfAdaptiveWeighting <loss/self_adaptive_weighting.rst> SelfAdaptiveWeighting <loss/self_adaptive_weighting.rst>
LinearWeighting <loss/linear_weighting.rst>

View File

@@ -0,0 +1,9 @@
LinearWeighting
=============================
.. currentmodule:: pina.loss.linear_weighting
.. automodule:: pina.loss.linear_weighting
.. autoclass:: LinearWeighting
:members:
:show-inheritance:

View File

@@ -8,6 +8,7 @@ __all__ = [
"ScalarWeighting", "ScalarWeighting",
"NeuralTangentKernelWeighting", "NeuralTangentKernelWeighting",
"SelfAdaptiveWeighting", "SelfAdaptiveWeighting",
"LinearWeighting",
] ]
from .loss_interface import LossInterface from .loss_interface import LossInterface
@@ -17,3 +18,4 @@ from .weighting_interface import WeightingInterface
from .scalar_weighting import ScalarWeighting from .scalar_weighting import ScalarWeighting
from .ntk_weighting import NeuralTangentKernelWeighting from .ntk_weighting import NeuralTangentKernelWeighting
from .self_adaptive_weighting import SelfAdaptiveWeighting from .self_adaptive_weighting import SelfAdaptiveWeighting
from .linear_weighting import LinearWeighting

View File

@@ -0,0 +1,64 @@
"""Module for the LinearWeighting class."""
from ..loss import WeightingInterface
from ..utils import check_consistency, check_positive_integer
class LinearWeighting(WeightingInterface):
"""
A weighting scheme that linearly scales weights from initial values to final
values over a specified number of epochs.
"""
def __init__(self, initial_weights, final_weights, target_epoch):
"""
:param dict initial_weights: The weights to be assigned to each loss
term at the beginning of training. The keys are the conditions and
the values are the corresponding weights. If a condition is not
present in the dictionary, the default value (1) is used.
:param dict final_weights: The weights to be assigned to each loss term
once the target epoch is reached. The keys are the conditions and
the values are the corresponding weights. If a condition is not
present in the dictionary, the default value (1) is used.
:param int target_epoch: The epoch at which the weights reach their
final values.
:raises ValueError: If the keys of the two dictionaries are not
consistent.
"""
super().__init__(update_every_n_epochs=1, aggregator="sum")
# Check consistency
check_consistency([initial_weights, final_weights], dict)
check_positive_integer(value=target_epoch, strict=True)
# Check that the keys of the two dictionaries are the same
if initial_weights.keys() != final_weights.keys():
raise ValueError(
"The keys of the initial_weights and final_weights "
"dictionaries must be the same."
)
# Initialization
self.initial_weights = initial_weights
self.final_weights = final_weights
self.target_epoch = target_epoch
def weights_update(self, losses):
"""
Update the weighting scheme based on the given losses.
:param dict losses: The dictionary of losses.
:return: The updated weights.
:rtype: dict
"""
return {
condition: self.last_saved_weights().get(
condition, self.initial_weights.get(condition, 1)
)
+ (
self.final_weights.get(condition, 1)
- self.initial_weights.get(condition, 1)
)
/ (self.target_epoch)
for condition in losses.keys()
}

View File

@@ -61,11 +61,10 @@ class NeuralTangentKernelWeighting(WeightingInterface):
losses_norm[condition] = grads.norm() losses_norm[condition] = grads.norm()
# Update the weights # Update the weights
self.weights = { return {
condition: self.alpha * self.weights.get(condition, 1) condition: self.alpha * self.last_saved_weights().get(condition, 1)
+ (1 - self.alpha) + (1 - self.alpha)
* losses_norm[condition] * losses_norm[condition]
/ sum(losses_norm.values()) / sum(losses_norm.values())
for condition in losses for condition in losses
} }
return self.weights

View File

@@ -17,7 +17,7 @@ class ScalarWeighting(WeightingInterface):
If a single scalar value is provided, it is assigned to all loss If a single scalar value is provided, it is assigned to all loss
terms. If a dictionary is provided, the keys are the conditions and terms. If a dictionary is provided, the keys are the conditions and
the values are the weights. If a condition is not present in the the values are the weights. If a condition is not present in the
dictionary, the default value is used. dictionary, the default value (1) is used.
:type weights: float | int | dict :type weights: float | int | dict
""" """
super().__init__(update_every_n_epochs=1, aggregator="sum") super().__init__(update_every_n_epochs=1, aggregator="sum")
@@ -29,11 +29,9 @@ class ScalarWeighting(WeightingInterface):
if isinstance(weights, dict): if isinstance(weights, dict):
self.values = weights self.values = weights
self.default_value_weights = 1 self.default_value_weights = 1
elif isinstance(weights, (float, int)): else:
self.values = {} self.values = {}
self.default_value_weights = weights self.default_value_weights = weights
else:
raise ValueError
def weights_update(self, losses): def weights_update(self, losses):
""" """

View File

@@ -0,0 +1,95 @@
import math
import pytest
from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.loss import LinearWeighting
from pina.problem.zoo import Poisson2DSquareProblem
# Initialize problem and model
problem = Poisson2DSquareProblem()
problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
# Weights for testing
init_weight_1 = {cond: 3 for cond in problem.conditions.keys()}
init_weight_2 = {cond: 4 for cond in problem.conditions.keys()}
final_weight_1 = {cond: 1 for cond in problem.conditions.keys()}
final_weight_2 = {cond: 5 for cond in problem.conditions.keys()}
@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2])
@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2])
@pytest.mark.parametrize("target_epoch", [5, 10])
def test_constructor(initial_weights, final_weights, target_epoch):
LinearWeighting(
initial_weights=initial_weights,
final_weights=final_weights,
target_epoch=target_epoch,
)
# Should fail if initial_weights is not a dictionary
with pytest.raises(ValueError):
LinearWeighting(
initial_weights=[1, 1, 1],
final_weights=final_weights,
target_epoch=target_epoch,
)
# Should fail if final_weights is not a dictionary
with pytest.raises(ValueError):
LinearWeighting(
initial_weights=initial_weights,
final_weights=[1, 1, 1],
target_epoch=target_epoch,
)
# Should fail if target_epoch is not an integer
with pytest.raises(AssertionError):
LinearWeighting(
initial_weights=initial_weights,
final_weights=final_weights,
target_epoch=1.5,
)
# Should fail if target_epoch is not positive
with pytest.raises(AssertionError):
LinearWeighting(
initial_weights=initial_weights,
final_weights=final_weights,
target_epoch=0,
)
# Should fail if dictionary keys do not match
with pytest.raises(ValueError):
LinearWeighting(
initial_weights={list(initial_weights.keys())[0]: 1},
final_weights=final_weights,
target_epoch=target_epoch,
)
@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2])
@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2])
@pytest.mark.parametrize("target_epoch", [5, 10])
def test_train_aggregation(initial_weights, final_weights, target_epoch):
weighting = LinearWeighting(
initial_weights=initial_weights,
final_weights=final_weights,
target_epoch=target_epoch,
)
solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=target_epoch, accelerator="cpu")
trainer.train()
# Check that weights are updated correctly
assert all(
math.isclose(
weighting.last_saved_weights()[cond],
final_weights[cond],
rel_tol=1e-5,
abs_tol=1e-8,
)
for cond in final_weights.keys()
)