weighting refactoring
Co-authored-by: Dario Coscia <dariocos99@gmail.com>
committed by Giovanni Canali
parent c42bdd575c
commit 96402baf20
@@ -1,7 +0,0 @@
-Weighting callbacks
-========================
-
-.. currentmodule:: pina.callback.linear_weight_update_callback
-.. autoclass:: LinearWeightUpdate
-   :members:
-   :show-inheritance:
@@ -4,11 +4,9 @@ __all__ = [
     "SwitchOptimizer",
     "MetricTracker",
     "PINAProgressBar",
-    "LinearWeightUpdate",
     "R3Refinement",
 ]
 
 from .optimizer_callback import SwitchOptimizer
 from .processing_callback import MetricTracker, PINAProgressBar
-from .linear_weight_update_callback import LinearWeightUpdate
 from .refinement import R3Refinement
@@ -1,87 +0,0 @@
-"""Module for the LinearWeightUpdate callback."""
-
-import warnings
-from lightning.pytorch.callbacks import Callback
-from ..utils import check_consistency
-from ..loss import ScalarWeighting
-
-
-class LinearWeightUpdate(Callback):
-    """
-    Callback to linearly adjust the weight of a condition from an
-    initial value to a target value over a specified number of epochs.
-    """
-
-    def __init__(
-        self, target_epoch, condition_name, initial_value, target_value
-    ):
-        """
-        Callback initialization.
-
-        :param int target_epoch: The epoch at which the weight of the condition
-            should reach the target value.
-        :param str condition_name: The name of the condition whose weight
-            should be adjusted.
-        :param float initial_value: The initial value of the weight.
-        :param float target_value: The target value of the weight.
-        """
-        super().__init__()
-        self.target_epoch = target_epoch
-        self.condition_name = condition_name
-        self.initial_value = initial_value
-        self.target_value = target_value
-
-        # Check consistency
-        check_consistency(self.target_epoch, int, subclass=False)
-        check_consistency(self.condition_name, str, subclass=False)
-        check_consistency(self.initial_value, (float, int), subclass=False)
-        check_consistency(self.target_value, (float, int), subclass=False)
-
-    def on_train_start(self, trainer, pl_module):
-        """
-        Initialize the weight of the condition to the specified `initial_value`.
-
-        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
-        :param SolverInterface pl_module: A
-            :class:`~pina.solver.solver.SolverInterface` instance.
-        """
-        # Check that the target epoch is valid
-        if not 0 < self.target_epoch <= trainer.max_epochs:
-            raise ValueError(
-                "`target_epoch` must be greater than 0"
-                " and less than or equal to `max_epochs`."
-            )
-
-        # Check that the condition is a problem condition
-        if self.condition_name not in pl_module.problem.conditions:
-            raise ValueError(
-                f"`{self.condition_name}` must be a problem condition."
-            )
-
-        # Check that the initial value is not equal to the target value
-        if self.initial_value == self.target_value:
-            warnings.warn(
-                "`initial_value` is equal to `target_value`. "
-                "No effective adjustment will be performed.",
-                UserWarning,
-            )
-
-        # Check that the weighting schema is ScalarWeighting
-        if not isinstance(pl_module.weighting, ScalarWeighting):
-            raise ValueError("The weighting schema must be ScalarWeighting.")
-
-        # Initialize the weight of the condition
-        pl_module.weighting.weights[self.condition_name] = self.initial_value
-
-    def on_train_epoch_start(self, trainer, pl_module):
-        """
-        Adjust the weight of the condition at each epoch.
-
-        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
-        :param SolverInterface pl_module: A
-            :class:`~pina.solver.solver.SolverInterface` instance.
-        """
-        if 0 < trainer.current_epoch <= self.target_epoch:
-            pl_module.weighting.weights[self.condition_name] += (
-                self.target_value - self.initial_value
-            ) / (self.target_epoch - 1)
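The deleted callback's schedule is a straight line: the weight starts at `initial_value` and gains a fixed increment of `(target_value - initial_value) / (target_epoch - 1)` at every epoch in `(0, target_epoch]`. A minimal arithmetic sketch (plain Python, values chosen for illustration; the target is reached exactly when `max_epochs == target_epoch`, as in the deleted tests):

initial_value, target_value, target_epoch = 1.0, 10.0, 10
step = (target_value - initial_value) / (target_epoch - 1)  # 1.0 per epoch

weight = initial_value
for epoch in range(target_epoch):  # Lightning epochs run 0 .. max_epochs - 1
    if 0 < epoch <= target_epoch:  # first increment happens at epoch 1
        weight += step

print(weight)  # 10.0: nine increments of 1.0 on top of the initial 1.0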
@@ -2,7 +2,7 @@
 
 import torch
 from .weighting_interface import WeightingInterface
-from ..utils import check_consistency
+from ..utils import check_consistency, in_range
 
 
 class NeuralTangentKernelWeighting(WeightingInterface):
@@ -20,32 +20,34 @@ class NeuralTangentKernelWeighting(WeightingInterface):
 
     """
 
-    def __init__(self, alpha=0.5):
+    def __init__(self, update_every_n_epochs=1, alpha=0.5):
         """
         Initialization of the :class:`NeuralTangentKernelWeighting` class.
 
+        :param int update_every_n_epochs: The number of training epochs between
+            weight updates. If set to 1, the weights are updated at every epoch.
+            Default is 1.
         :param float alpha: The moving-average parameter balancing the previous
            weights and the new gradient-based values.
         :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
         """
-        super().__init__()
+        super().__init__(update_every_n_epochs=update_every_n_epochs)
 
         # Check consistency
         check_consistency(alpha, float)
-        if alpha < 0 or alpha > 1:
-            raise ValueError("alpha should be a value between 0 and 1")
+        if not in_range(alpha, [0, 1], strict=False):
+            raise ValueError("alpha must be in range [0, 1].")
 
         # Initialize parameters
         self.alpha = alpha
         self.weights = {}
-        self.default_value_weights = 1.0
 
-    def aggregate(self, losses):
+    def weights_update(self, losses):
         """
-        Weight the losses according to the Neural Tangent Kernel algorithm.
+        Update the weighting scheme based on the given losses.
 
-        :param dict(torch.Tensor) input: The dictionary of losses.
-        :return: The aggregation of the losses. It should be a scalar Tensor.
-        :rtype: torch.Tensor
+        :param dict losses: The dictionary of losses.
+        :return: The updated weights.
+        :rtype: dict
         """
         # Define a dictionary to store the norms of the gradients
         losses_norm = {}
@@ -60,14 +62,10 @@ class NeuralTangentKernelWeighting(WeightingInterface):
 
         # Update the weights
         self.weights = {
-            condition: self.alpha
-            * self.weights.get(condition, self.default_value_weights)
+            condition: self.alpha * self.weights.get(condition, 1)
             + (1 - self.alpha)
             * losses_norm[condition]
             / sum(losses_norm.values())
             for condition in losses
         }
 
-        return sum(
-            self.weights[condition] * loss for condition, loss in losses.items()
-        )
+        return self.weights
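Each weight is thus an exponential moving average between its previous value and the condition's normalized gradient norm. A small worked example of one update step (plain Python with made-up norms, mirroring the formula above):

alpha = 0.5
weights = {"physics": 1.0, "boundary": 1.0}  # previous weights (default 1)
losses_norm = {"physics": 4.0, "boundary": 1.0}  # made-up gradient norms

total = sum(losses_norm.values())  # 5.0
weights = {
    c: alpha * weights[c] + (1 - alpha) * losses_norm[c] / total
    for c in weights
}
# physics: 0.5 * 1.0 + 0.5 * 0.8 = 0.9; boundary: 0.5 * 1.0 + 0.5 * 0.2 = 0.6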
@@ -4,22 +4,6 @@ from .weighting_interface import WeightingInterface
 from ..utils import check_consistency
 
 
-class _NoWeighting(WeightingInterface):
-    """
-    Weighting scheme that does not apply any weighting to the losses.
-    """
-
-    def aggregate(self, losses):
-        """
-        Aggregate the losses.
-
-        :param dict losses: The dictionary of losses.
-        :return: The aggregated losses.
-        :rtype: torch.Tensor
-        """
-        return sum(losses.values())
-
-
 class ScalarWeighting(WeightingInterface):
     """
     Weighting scheme that assigns a scalar weight to each loss term.
@@ -36,28 +20,42 @@ class ScalarWeighting(WeightingInterface):
         dictionary, the default value is used.
         :type weights: float | int | dict
         """
-        super().__init__()
+        super().__init__(update_every_n_epochs=1, aggregator="sum")
 
         # Check consistency
         check_consistency([weights], (float, dict, int))
 
-        # Weights initialization
-        if isinstance(weights, (float, int)):
+        # Initialization
+        if isinstance(weights, dict):
+            self.values = weights
+            self.default_value_weights = 1
+        elif isinstance(weights, (float, int)):
+            self.values = {}
             self.default_value_weights = weights
-            self.weights = {}
         else:
-            self.default_value_weights = 1.0
-            self.weights = weights
+            raise ValueError("weights must be a float, an int, or a dict.")
 
-    def aggregate(self, losses):
+    def weights_update(self, losses):
         """
-        Aggregate the losses.
+        Update the weighting scheme based on the given losses.
 
         :param dict losses: The dictionary of losses.
-        :return: The aggregated losses.
-        :rtype: torch.Tensor
+        :return: The updated weights.
+        :rtype: dict
         """
-        return sum(
-            self.weights.get(condition, self.default_value_weights) * loss
-            for condition, loss in losses.items()
-        )
+        return {
+            condition: self.values.get(condition, self.default_value_weights)
+            for condition in losses.keys()
+        }
+
+
+class _NoWeighting(ScalarWeighting):
+    """
+    Weighting scheme that does not apply any weighting to the losses.
+    """
+
+    def __init__(self):
+        """
+        Initialization of the :class:`_NoWeighting` class.
+        """
+        super().__init__(weights=1)
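After the refactor, `ScalarWeighting.weights_update` only echoes the configured per-condition values; the weighted sum itself is computed by the inherited `aggregate`. A minimal usage sketch (the condition names are placeholders; wiring into a solver follows the tests further below):

from pina.loss import ScalarWeighting

# Conditions missing from the dict fall back to default_value_weights (1).
weighting = ScalarWeighting(weights={"physics": 10.0, "boundary": 1.0})

# As in the test suite, the scheme is then passed to a solver, e.g.:
# solver = PINN(problem=problem, model=model, weighting=weighting)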
@@ -2,7 +2,6 @@
 
 import torch
 from .weighting_interface import WeightingInterface
-from ..utils import check_positive_integer
 
 
 class SelfAdaptiveWeighting(WeightingInterface):
@@ -22,42 +21,24 @@ class SelfAdaptiveWeighting(WeightingInterface):
 
     """
 
-    def __init__(self, k=100):
+    def __init__(self, update_every_n_epochs=1):
         """
         Initialization of the :class:`SelfAdaptiveWeighting` class.
 
-        :param int k: The number of epochs after which the weights are updated.
-            Default is 100.
-
-        :raises ValueError: If ``k`` is not a positive integer.
+        :param int update_every_n_epochs: The number of training epochs between
+            weight updates. If set to 1, the weights are updated at every epoch.
+            Default is 1.
         """
-        super().__init__()
+        super().__init__(update_every_n_epochs=update_every_n_epochs)
 
-        # Check consistency
-        check_positive_integer(value=k, strict=True)
-
-        # Initialize parameters
-        self.k = k
-        self.weights = {}
-        self.default_value_weights = 1.0
-
-    def aggregate(self, losses):
+    def weights_update(self, losses):
         """
-        Weight the losses according to the self-adaptive algorithm.
+        Update the weighting scheme based on the given losses.
 
-        :param dict(torch.Tensor) losses: The dictionary of losses.
-        :return: The aggregation of the losses. It should be a scalar Tensor.
-        :rtype: torch.Tensor
+        :param dict losses: The dictionary of losses.
+        :return: The updated weights.
+        :rtype: dict
         """
-        # If weights have not been initialized, set them to 1
-        if not self.weights:
-            self.weights = {
-                condition: self.default_value_weights for condition in losses
-            }
-
-        # Update every k epochs
-        if self.solver.trainer.current_epoch % self.k == 0:
-
         # Define a dictionary to store the norms of the gradients
         losses_norm = {}
@@ -70,11 +51,7 @@ class SelfAdaptiveWeighting(WeightingInterface):
             losses_norm[condition] = grads.norm()
 
         # Update the weights
-        self.weights = {
+        return {
             condition: sum(losses_norm.values()) / losses_norm[condition]
             for condition in losses
         }
-
-        return sum(
-            self.weights[condition] * loss for condition, loss in losses.items()
-        )
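The self-adaptive rule gives each condition the inverse of its relative gradient magnitude, so conditions with small gradients receive large weights. A worked example with made-up norms:

losses_norm = {"physics": 4.0, "boundary": 1.0}  # made-up gradient norms

total = sum(losses_norm.values())  # 5.0
weights = {c: total / losses_norm[c] for c in losses_norm}
# physics: 5.0 / 4.0 = 1.25; boundary: 5.0 / 1.0 = 5.0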
@@ -1,6 +1,10 @@
 """Module for the Weighting Interface."""
 
 from abc import ABCMeta, abstractmethod
+from typing import final
+from ..utils import check_positive_integer, is_function
+
+_AGGREGATE_METHODS = {"sum": sum, "mean": lambda x: sum(x) / len(x)}
 
 
 class WeightingInterface(metaclass=ABCMeta):
@@ -9,20 +13,93 @@ class WeightingInterface(metaclass=ABCMeta):
     should inherit from this class.
     """
 
-    def __init__(self):
+    def __init__(self, update_every_n_epochs=1, aggregator="sum"):
         """
         Initialization of the :class:`WeightingInterface` class.
+
+        :param int update_every_n_epochs: The number of training epochs between
+            weight updates. If set to 1, the weights are updated at every epoch.
+            This parameter is ignored by static weighting schemes. Default is 1.
+        :param aggregator: The aggregation method. Either:
+
+            - 'sum' → sum of the weighted losses
+            - 'mean' → arithmetic mean of the weighted losses
+            - callable → custom aggregation function
+
+        :type aggregator: str | Callable
         """
+        # Check consistency
+        check_positive_integer(value=update_every_n_epochs, strict=True)
+
+        # Aggregation
+        if isinstance(aggregator, str):
+            if aggregator not in _AGGREGATE_METHODS:
+                raise ValueError(
+                    f"Invalid aggregator '{aggregator}'. Must be one of "
+                    f"{list(_AGGREGATE_METHODS.keys())}."
+                )
+            aggregator = _AGGREGATE_METHODS[aggregator]
+
+        elif not is_function(aggregator):
+            raise TypeError(
+                "Aggregator must be either a string or a callable, "
+                f"got {type(aggregator).__name__}."
+            )
+
         # Initialization
         self._solver = None
+        self.update_every_n_epochs = update_every_n_epochs
+        self.aggregator_fn = aggregator
+        self._saved_weights = {}
 
     @abstractmethod
-    def aggregate(self, losses):
+    def weights_update(self, losses):
         """
-        Aggregate the losses.
+        Update the weighting scheme based on the given losses.
+
+        This method must be implemented by subclasses. Its role is to update the
+        values of the weights. The updated weights will then be used by
+        :meth:`aggregate` to compute the final aggregated loss.
 
         :param dict losses: The dictionary of losses.
-        :return: The aggregated losses.
-        :rtype: torch.Tensor
+        :return: The updated weights.
+        :rtype: dict
         """
 
+    @final
+    def aggregate(self, losses):
+        """
+        Update the weights (if needed) and aggregate the given losses.
+
+        This method first checks whether the loss weights need to be updated
+        based on the current epoch and the ``update_every_n_epochs`` setting.
+        If an update is required, it calls :meth:`weights_update` to refresh the
+        weights. Afterwards, it aggregates the weighted losses into a single
+        scalar tensor using the configured aggregator function. This method must
+        not be overridden.
+
+        :param dict losses: The dictionary of losses.
+        :return: The aggregated loss tensor.
+        :rtype: torch.Tensor
+        """
+        # Update weights
+        if self.solver.trainer.current_epoch % self.update_every_n_epochs == 0:
+            self._saved_weights = self.weights_update(losses)
+
+        # Aggregate. Using direct indexing instead of .get() ensures that a
+        # KeyError is raised if the expected condition is missing from the dict.
+        # The weighted losses are materialized in a list so that sized
+        # aggregators such as 'mean' can take their length.
+        weighted_losses = [
+            self._saved_weights[condition] * loss
+            for condition, loss in losses.items()
+        ]
+        return self.aggregator_fn(weighted_losses)
+
+    def last_saved_weights(self):
+        """
+        Get the last saved weights.
+
+        :return: The last saved weights.
+        :rtype: dict
+        """
+        return self._saved_weights
 
     @property
     def solver(self):
         """
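The refactored interface is a template method: subclasses implement only `weights_update`, while the final `aggregate` owns the update schedule and the reduction. A minimal sketch of a custom scheme (the class name is invented and the module path is an assumption inferred from the relative imports above):

from pina.loss.weighting_interface import WeightingInterface


class UniformWeighting(WeightingInterface):
    """Toy scheme: every condition gets the same constant weight."""

    def weights_update(self, losses):
        # Called by aggregate() every `update_every_n_epochs` epochs.
        return {condition: 1.0 for condition in losses}


# Hypothetical usage: update every 5 epochs, average the weighted losses.
# weighting = UniformWeighting(update_every_n_epochs=5, aggregator="mean")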
@@ -240,3 +240,31 @@ def check_positive_integer(value, strict=True):
     assert (
         isinstance(value, int) and value >= 0
     ), f"Expected a non-negative integer, got {value}."
+
+
+def in_range(value, range_vals, strict=True):
+    """
+    Check if a value is within a specified range.
+
+    :param value: The numeric value to check.
+    :type value: float | int
+    :param list range_vals: A list of two numbers representing the range
+        limits. The first element specifies the lower bound, and the second
+        specifies the upper bound.
+    :param bool strict: If True, the comparison is strict and the bounds are
+        excluded from the range. Default is True.
+    :return: True if the value satisfies the range condition, False otherwise.
+    :rtype: bool
+    """
+    # Validate inputs
+    check_consistency(value, (float, int))
+    check_consistency(range_vals, (float, int))
+    assert (
+        isinstance(range_vals, list) and len(range_vals) == 2
+    ), "range_vals must be a list of two numbers [lower, upper]"
+    lower, upper = range_vals
+
+    # Check the range
+    if strict:
+        return lower < value < upper
+
+    return lower <= value <= upper
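Quick behavior check of the new helper (assuming it is exposed as `pina.utils.in_range`; the values are illustrative):

from pina.utils import in_range

in_range(0.5, [0, 1])                # True: 0 < 0.5 < 1
in_range(1.0, [0, 1])                # False: strict bounds exclude 1
in_range(1.0, [0, 1], strict=False)  # True: inclusive comparison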
@@ -1,164 +0,0 @@
-import pytest
-import math
-from pina.solver import PINN
-from pina.loss import ScalarWeighting
-from pina.trainer import Trainer
-from pina.model import FeedForward
-from pina.problem.zoo import Poisson2DSquareProblem as Poisson
-from pina.callback import LinearWeightUpdate
-
-
-# Define the problem
-poisson_problem = Poisson()
-poisson_problem.discretise_domain(50, "grid")
-cond_name = list(poisson_problem.conditions.keys())[0]
-
-# Define the model
-model = FeedForward(
-    input_dimensions=len(poisson_problem.input_variables),
-    output_dimensions=len(poisson_problem.output_variables),
-    layers=[32, 32],
-)
-
-# Define the weighting schema
-weights_dict = {key: 1 for key in poisson_problem.conditions.keys()}
-weighting = ScalarWeighting(weights=weights_dict)
-
-# Define the solver
-solver = PINN(problem=poisson_problem, model=model, weighting=weighting)
-
-# Value used for testing
-epochs = 10
-
-
-@pytest.mark.parametrize("initial_value", [1, 5.5])
-@pytest.mark.parametrize("target_value", [10, 25.5])
-def test_constructor(initial_value, target_value):
-    LinearWeightUpdate(
-        target_epoch=epochs,
-        condition_name=cond_name,
-        initial_value=initial_value,
-        target_value=target_value,
-    )
-
-    # target_epoch must be an int
-    with pytest.raises(ValueError):
-        LinearWeightUpdate(
-            target_epoch=10.0,
-            condition_name=cond_name,
-            initial_value=0,
-            target_value=1,
-        )
-
-    # condition_name must be a str
-    with pytest.raises(ValueError):
-        LinearWeightUpdate(
-            target_epoch=epochs,
-            condition_name=100,
-            initial_value=0,
-            target_value=1,
-        )
-
-    # initial_value must be a float or an int
-    with pytest.raises(ValueError):
-        LinearWeightUpdate(
-            target_epoch=epochs,
-            condition_name=cond_name,
-            initial_value="0",
-            target_value=1,
-        )
-
-    # target_value must be a float or an int
-    with pytest.raises(ValueError):
-        LinearWeightUpdate(
-            target_epoch=epochs,
-            condition_name=cond_name,
-            initial_value=0,
-            target_value="1",
-        )
-
-
-@pytest.mark.parametrize("initial_value, target_value", [(1, 10), (10, 1)])
-def test_training(initial_value, target_value):
-    callback = LinearWeightUpdate(
-        target_epoch=epochs,
-        condition_name=cond_name,
-        initial_value=initial_value,
-        target_value=target_value,
-    )
-    trainer = Trainer(
-        solver=solver,
-        callbacks=[callback],
-        accelerator="cpu",
-        max_epochs=epochs,
-    )
-    trainer.train()
-
-    # Check that the final weight value matches the target value
-    final_value = solver.weighting.weights[cond_name]
-    assert math.isclose(final_value, target_value)
-
-    # target_epoch must be greater than 0
-    with pytest.raises(ValueError):
-        callback = LinearWeightUpdate(
-            target_epoch=0,
-            condition_name=cond_name,
-            initial_value=0,
-            target_value=1,
-        )
-        trainer = Trainer(
-            solver=solver,
-            callbacks=[callback],
-            accelerator="cpu",
-            max_epochs=5,
-        )
-        trainer.train()
-
-    # target_epoch must be less than or equal to max_epochs
-    with pytest.raises(ValueError):
-        callback = LinearWeightUpdate(
-            target_epoch=epochs,
-            condition_name=cond_name,
-            initial_value=0,
-            target_value=1,
-        )
-        trainer = Trainer(
-            solver=solver,
-            callbacks=[callback],
-            accelerator="cpu",
-            max_epochs=epochs - 1,
-        )
-        trainer.train()
-
-    # condition_name must be a problem condition
-    with pytest.raises(ValueError):
-        callback = LinearWeightUpdate(
-            target_epoch=epochs,
-            condition_name="not_a_condition",
-            initial_value=0,
-            target_value=1,
-        )
-        trainer = Trainer(
-            solver=solver,
-            callbacks=[callback],
-            accelerator="cpu",
-            max_epochs=epochs,
-        )
-        trainer.train()
-
-    # Weighting schema must be ScalarWeighting
-    with pytest.raises(ValueError):
-        callback = LinearWeightUpdate(
-            target_epoch=epochs,
-            condition_name=cond_name,
-            initial_value=0,
-            target_value=1,
-        )
-        unweighted_solver = PINN(problem=poisson_problem, model=model)
-        trainer = Trainer(
-            solver=unweighted_solver,
-            callbacks=[callback],
-            accelerator="cpu",
-            max_epochs=epochs,
-        )
-        trainer.train()
@@ -12,22 +12,42 @@ problem.discretise_domain(10)
 model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 
 
+@pytest.mark.parametrize("update_every_n_epochs", [1, 10, 100, 1000])
 @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
-def test_constructor(alpha):
-    NeuralTangentKernelWeighting(alpha=alpha)
+def test_constructor(update_every_n_epochs, alpha):
+    NeuralTangentKernelWeighting(
+        update_every_n_epochs=update_every_n_epochs, alpha=alpha
+    )
 
     # Should fail if alpha is not >= 0
     with pytest.raises(ValueError):
-        NeuralTangentKernelWeighting(alpha=-0.1)
+        NeuralTangentKernelWeighting(
+            update_every_n_epochs=update_every_n_epochs, alpha=-0.1
+        )
 
     # Should fail if alpha is not <= 1
     with pytest.raises(ValueError):
         NeuralTangentKernelWeighting(alpha=1.1)
 
+    # Should fail if update_every_n_epochs is not an integer
+    with pytest.raises(AssertionError):
+        NeuralTangentKernelWeighting(update_every_n_epochs=1.5)
+
+    # Should fail if update_every_n_epochs is not > 0
+    with pytest.raises(AssertionError):
+        NeuralTangentKernelWeighting(update_every_n_epochs=0)
+
+    # Should fail if update_every_n_epochs is negative
+    with pytest.raises(AssertionError):
+        NeuralTangentKernelWeighting(update_every_n_epochs=-3)
+
 
+@pytest.mark.parametrize("update_every_n_epochs", [1, 3])
 @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
-def test_train_aggregation(alpha):
-    weighting = NeuralTangentKernelWeighting(alpha=alpha)
+def test_train_aggregation(update_every_n_epochs, alpha):
+    weighting = NeuralTangentKernelWeighting(
+        update_every_n_epochs=update_every_n_epochs, alpha=alpha
+    )
     solver = PINN(problem=problem, model=model, weighting=weighting)
     trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
     trainer.train()
@@ -29,20 +29,6 @@ def test_constructor(weights):
         ScalarWeighting(weights=[1, 2, 3])
 
 
-@pytest.mark.parametrize(
-    "weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
-)
-def test_aggregate(weights):
-    weighting = ScalarWeighting(weights=weights)
-    losses = dict(
-        zip(
-            condition_names,
-            [torch.randn(1) for _ in range(len(condition_names))],
-        )
-    )
-    weighting.aggregate(losses=losses)
-
-
 @pytest.mark.parametrize(
     "weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
 )
@@ -12,26 +12,28 @@ problem.discretise_domain(10)
 model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 
 
-@pytest.mark.parametrize("k", [10, 100, 1000])
-def test_constructor(k):
-    SelfAdaptiveWeighting(k=k)
+@pytest.mark.parametrize("update_every_n_epochs", [10, 100, 1000])
+def test_constructor(update_every_n_epochs):
+    SelfAdaptiveWeighting(update_every_n_epochs=update_every_n_epochs)
 
-    # Should fail if k is not an integer
+    # Should fail if update_every_n_epochs is not an integer
     with pytest.raises(AssertionError):
-        SelfAdaptiveWeighting(k=1.5)
+        SelfAdaptiveWeighting(update_every_n_epochs=1.5)
 
-    # Should fail if k is not > 0
+    # Should fail if update_every_n_epochs is not > 0
     with pytest.raises(AssertionError):
-        SelfAdaptiveWeighting(k=0)
+        SelfAdaptiveWeighting(update_every_n_epochs=0)
 
-    # Should fail if k is not > 0
+    # Should fail if update_every_n_epochs is negative
     with pytest.raises(AssertionError):
-        SelfAdaptiveWeighting(k=-3)
+        SelfAdaptiveWeighting(update_every_n_epochs=-3)
 
 
-@pytest.mark.parametrize("k", [2, 3])
-def test_train_aggregation(k):
-    weighting = SelfAdaptiveWeighting(k=k)
+@pytest.mark.parametrize("update_every_n_epochs", [1, 3])
+def test_train_aggregation(update_every_n_epochs):
+    weighting = SelfAdaptiveWeighting(
+        update_every_n_epochs=update_every_n_epochs
+    )
     solver = PINN(problem=problem, model=model, weighting=weighting)
     trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
     trainer.train()