add linear weighting
This commit is contained in:
committed by
Giovanni Canali
parent
96402baf20
commit
ef3542486c
@@ -8,6 +8,7 @@ __all__ = [
|
||||
"ScalarWeighting",
|
||||
"NeuralTangentKernelWeighting",
|
||||
"SelfAdaptiveWeighting",
|
||||
"LinearWeighting",
|
||||
]
|
||||
|
||||
from .loss_interface import LossInterface
|
||||
@@ -17,3 +18,4 @@ from .weighting_interface import WeightingInterface
|
||||
from .scalar_weighting import ScalarWeighting
|
||||
from .ntk_weighting import NeuralTangentKernelWeighting
|
||||
from .self_adaptive_weighting import SelfAdaptiveWeighting
|
||||
from .linear_weighting import LinearWeighting
|
||||
|
||||
64
pina/loss/linear_weighting.py
Normal file
64
pina/loss/linear_weighting.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Module for the LinearWeighting class."""
|
||||
|
||||
from ..loss import WeightingInterface
|
||||
from ..utils import check_consistency, check_positive_integer
|
||||
|
||||
|
||||
class LinearWeighting(WeightingInterface):
|
||||
"""
|
||||
A weighting scheme that linearly scales weights from initial values to final
|
||||
values over a specified number of epochs.
|
||||
"""
|
||||
|
||||
def __init__(self, initial_weights, final_weights, target_epoch):
|
||||
"""
|
||||
:param dict initial_weights: The weights to be assigned to each loss
|
||||
term at the beginning of training. The keys are the conditions and
|
||||
the values are the corresponding weights. If a condition is not
|
||||
present in the dictionary, the default value (1) is used.
|
||||
:param dict final_weights: The weights to be assigned to each loss term
|
||||
once the target epoch is reached. The keys are the conditions and
|
||||
the values are the corresponding weights. If a condition is not
|
||||
present in the dictionary, the default value (1) is used.
|
||||
:param int target_epoch: The epoch at which the weights reach their
|
||||
final values.
|
||||
:raises ValueError: If the keys of the two dictionaries are not
|
||||
consistent.
|
||||
"""
|
||||
super().__init__(update_every_n_epochs=1, aggregator="sum")
|
||||
|
||||
# Check consistency
|
||||
check_consistency([initial_weights, final_weights], dict)
|
||||
check_positive_integer(value=target_epoch, strict=True)
|
||||
|
||||
# Check that the keys of the two dictionaries are the same
|
||||
if initial_weights.keys() != final_weights.keys():
|
||||
raise ValueError(
|
||||
"The keys of the initial_weights and final_weights "
|
||||
"dictionaries must be the same."
|
||||
)
|
||||
|
||||
# Initialization
|
||||
self.initial_weights = initial_weights
|
||||
self.final_weights = final_weights
|
||||
self.target_epoch = target_epoch
|
||||
|
||||
def weights_update(self, losses):
|
||||
"""
|
||||
Update the weighting scheme based on the given losses.
|
||||
|
||||
:param dict losses: The dictionary of losses.
|
||||
:return: The updated weights.
|
||||
:rtype: dict
|
||||
"""
|
||||
return {
|
||||
condition: self.last_saved_weights().get(
|
||||
condition, self.initial_weights.get(condition, 1)
|
||||
)
|
||||
+ (
|
||||
self.final_weights.get(condition, 1)
|
||||
- self.initial_weights.get(condition, 1)
|
||||
)
|
||||
/ (self.target_epoch)
|
||||
for condition in losses.keys()
|
||||
}
|
||||
@@ -61,11 +61,10 @@ class NeuralTangentKernelWeighting(WeightingInterface):
|
||||
losses_norm[condition] = grads.norm()
|
||||
|
||||
# Update the weights
|
||||
self.weights = {
|
||||
condition: self.alpha * self.weights.get(condition, 1)
|
||||
return {
|
||||
condition: self.alpha * self.last_saved_weights().get(condition, 1)
|
||||
+ (1 - self.alpha)
|
||||
* losses_norm[condition]
|
||||
/ sum(losses_norm.values())
|
||||
for condition in losses
|
||||
}
|
||||
return self.weights
|
||||
|
||||
@@ -17,7 +17,7 @@ class ScalarWeighting(WeightingInterface):
|
||||
If a single scalar value is provided, it is assigned to all loss
|
||||
terms. If a dictionary is provided, the keys are the conditions and
|
||||
the values are the weights. If a condition is not present in the
|
||||
dictionary, the default value is used.
|
||||
dictionary, the default value (1) is used.
|
||||
:type weights: float | int | dict
|
||||
"""
|
||||
super().__init__(update_every_n_epochs=1, aggregator="sum")
|
||||
@@ -29,11 +29,9 @@ class ScalarWeighting(WeightingInterface):
|
||||
if isinstance(weights, dict):
|
||||
self.values = weights
|
||||
self.default_value_weights = 1
|
||||
elif isinstance(weights, (float, int)):
|
||||
else:
|
||||
self.values = {}
|
||||
self.default_value_weights = weights
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def weights_update(self, losses):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user