weighting refactory

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
giovanni
2025-09-01 11:00:14 +02:00
committed by Giovanni Canali
parent c42bdd575c
commit 96402baf20
12 changed files with 214 additions and 388 deletions

View File

@@ -1,7 +0,0 @@
Weighting callbacks
========================
.. currentmodule:: pina.callback.linear_weight_update_callback
.. autoclass:: LinearWeightUpdate
:members:
:show-inheritance:

View File

@@ -4,11 +4,9 @@ __all__ = [
"SwitchOptimizer", "SwitchOptimizer",
"MetricTracker", "MetricTracker",
"PINAProgressBar", "PINAProgressBar",
"LinearWeightUpdate",
"R3Refinement", "R3Refinement",
] ]
from .optimizer_callback import SwitchOptimizer from .optimizer_callback import SwitchOptimizer
from .processing_callback import MetricTracker, PINAProgressBar from .processing_callback import MetricTracker, PINAProgressBar
from .linear_weight_update_callback import LinearWeightUpdate
from .refinement import R3Refinement from .refinement import R3Refinement

View File

@@ -1,87 +0,0 @@
"""Module for the LinearWeightUpdate callback."""
import warnings
from lightning.pytorch.callbacks import Callback
from ..utils import check_consistency
from ..loss import ScalarWeighting
class LinearWeightUpdate(Callback):
"""
Callback to linearly adjust the weight of a condition from an
initial value to a target value over a specified number of epochs.
"""
def __init__(
self, target_epoch, condition_name, initial_value, target_value
):
"""
Callback initialization.
:param int target_epoch: The epoch at which the weight of the condition
should reach the target value.
:param str condition_name: The name of the condition whose weight
should be adjusted.
:param float initial_value: The initial value of the weight.
:param float target_value: The target value of the weight.
"""
super().__init__()
self.target_epoch = target_epoch
self.condition_name = condition_name
self.initial_value = initial_value
self.target_value = target_value
# Check consistency
check_consistency(self.target_epoch, int, subclass=False)
check_consistency(self.condition_name, str, subclass=False)
check_consistency(self.initial_value, (float, int), subclass=False)
check_consistency(self.target_value, (float, int), subclass=False)
def on_train_start(self, trainer, pl_module):
"""
Initialize the weight of the condition to the specified `initial_value`.
:param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
:param SolverInterface pl_module: A
:class:`~pina.solver.solver.SolverInterface` instance.
"""
# Check that the target epoch is valid
if not 0 < self.target_epoch <= trainer.max_epochs:
raise ValueError(
"`target_epoch` must be greater than 0"
" and less than or equal to `max_epochs`."
)
# Check that the condition is a problem condition
if self.condition_name not in pl_module.problem.conditions:
raise ValueError(
f"`{self.condition_name}` must be a problem condition."
)
# Check that the initial value is not equal to the target value
if self.initial_value == self.target_value:
warnings.warn(
"`initial_value` is equal to `target_value`. "
"No effective adjustment will be performed.",
UserWarning,
)
# Check that the weighting schema is ScalarWeighting
if not isinstance(pl_module.weighting, ScalarWeighting):
raise ValueError("The weighting schema must be ScalarWeighting.")
# Initialize the weight of the condition
pl_module.weighting.weights[self.condition_name] = self.initial_value
def on_train_epoch_start(self, trainer, pl_module):
"""
Adjust at each epoch the weight of the condition.
:param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
:param SolverInterface pl_module: A
:class:`~pina.solver.solver.SolverInterface` instance.
"""
if 0 < trainer.current_epoch <= self.target_epoch:
pl_module.weighting.weights[self.condition_name] += (
self.target_value - self.initial_value
) / (self.target_epoch - 1)

View File

@@ -2,7 +2,7 @@
import torch import torch
from .weighting_interface import WeightingInterface from .weighting_interface import WeightingInterface
from ..utils import check_consistency from ..utils import check_consistency, in_range
class NeuralTangentKernelWeighting(WeightingInterface): class NeuralTangentKernelWeighting(WeightingInterface):
@@ -20,32 +20,34 @@ class NeuralTangentKernelWeighting(WeightingInterface):
""" """
def __init__(self, alpha=0.5): def __init__(self, update_every_n_epochs=1, alpha=0.5):
""" """
Initialization of the :class:`NeuralTangentKernelWeighting` class. Initialization of the :class:`NeuralTangentKernelWeighting` class.
:param int update_every_n_epochs: The number of training epochs between
weight updates. If set to 1, the weights are updated at every epoch.
Default is 1.
:param float alpha: The alpha parameter. :param float alpha: The alpha parameter.
:raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive). :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive).
""" """
super().__init__() super().__init__(update_every_n_epochs=update_every_n_epochs)
# Check consistency # Check consistency
check_consistency(alpha, float) check_consistency(alpha, float)
if alpha < 0 or alpha > 1: if not in_range(alpha, [0, 1], strict=False):
raise ValueError("alpha should be a value between 0 and 1") raise ValueError("alpha must be in range (0, 1).")
# Initialize parameters # Initialize parameters
self.alpha = alpha self.alpha = alpha
self.weights = {} self.weights = {}
self.default_value_weights = 1.0
def aggregate(self, losses): def weights_update(self, losses):
""" """
Weight the losses according to the Neural Tangent Kernel algorithm. Update the weighting scheme based on the given losses.
:param dict(torch.Tensor) input: The dictionary of losses. :param dict losses: The dictionary of losses.
:return: The aggregation of the losses. It should be a scalar Tensor. :return: The updated weights.
:rtype: torch.Tensor :rtype: dict
""" """
# Define a dictionary to store the norms of the gradients # Define a dictionary to store the norms of the gradients
losses_norm = {} losses_norm = {}
@@ -60,14 +62,10 @@ class NeuralTangentKernelWeighting(WeightingInterface):
# Update the weights # Update the weights
self.weights = { self.weights = {
condition: self.alpha condition: self.alpha * self.weights.get(condition, 1)
* self.weights.get(condition, self.default_value_weights)
+ (1 - self.alpha) + (1 - self.alpha)
* losses_norm[condition] * losses_norm[condition]
/ sum(losses_norm.values()) / sum(losses_norm.values())
for condition in losses for condition in losses
} }
return self.weights
return sum(
self.weights[condition] * loss for condition, loss in losses.items()
)

View File

@@ -4,22 +4,6 @@ from .weighting_interface import WeightingInterface
from ..utils import check_consistency from ..utils import check_consistency
class _NoWeighting(WeightingInterface):
"""
Weighting scheme that does not apply any weighting to the losses.
"""
def aggregate(self, losses):
"""
Aggregate the losses.
:param dict losses: The dictionary of losses.
:return: The aggregated losses.
:rtype: torch.Tensor
"""
return sum(losses.values())
class ScalarWeighting(WeightingInterface): class ScalarWeighting(WeightingInterface):
""" """
Weighting scheme that assigns a scalar weight to each loss term. Weighting scheme that assigns a scalar weight to each loss term.
@@ -36,28 +20,42 @@ class ScalarWeighting(WeightingInterface):
dictionary, the default value is used. dictionary, the default value is used.
:type weights: float | int | dict :type weights: float | int | dict
""" """
super().__init__() super().__init__(update_every_n_epochs=1, aggregator="sum")
# Check consistency # Check consistency
check_consistency([weights], (float, dict, int)) check_consistency([weights], (float, dict, int))
# Weights initialization # Initialization
if isinstance(weights, (float, int)): if isinstance(weights, dict):
self.values = weights
self.default_value_weights = 1
elif isinstance(weights, (float, int)):
self.values = {}
self.default_value_weights = weights self.default_value_weights = weights
self.weights = {}
else: else:
self.default_value_weights = 1.0 raise ValueError
self.weights = weights
def aggregate(self, losses): def weights_update(self, losses):
""" """
Aggregate the losses. Update the weighting scheme based on the given losses.
:param dict losses: The dictionary of losses. :param dict losses: The dictionary of losses.
:return: The aggregated losses. :return: The updated weights.
:rtype: torch.Tensor :rtype: dict
""" """
return sum( return {
self.weights.get(condition, self.default_value_weights) * loss condition: self.values.get(condition, self.default_value_weights)
for condition, loss in losses.items() for condition in losses.keys()
) }
class _NoWeighting(ScalarWeighting):
"""
Weighting scheme that does not apply any weighting to the losses.
"""
def __init__(self):
"""
Initialization of the :class:`_NoWeighting` class.
"""
super().__init__(weights=1)

View File

@@ -2,7 +2,6 @@
import torch import torch
from .weighting_interface import WeightingInterface from .weighting_interface import WeightingInterface
from ..utils import check_positive_integer
class SelfAdaptiveWeighting(WeightingInterface): class SelfAdaptiveWeighting(WeightingInterface):
@@ -22,59 +21,37 @@ class SelfAdaptiveWeighting(WeightingInterface):
""" """
def __init__(self, k=100): def __init__(self, update_every_n_epochs=1):
""" """
Initialization of the :class:`SelfAdaptiveWeighting` class. Initialization of the :class:`SelfAdaptiveWeighting` class.
:param int k: The number of epochs after which the weights are updated. :param int update_every_n_epochs: The number of training epochs between
Default is 100. weight updates. If set to 1, the weights are updated at every epoch.
Default is 1.
:raises ValueError: If ``k`` is not a positive integer.
""" """
super().__init__() super().__init__(update_every_n_epochs=update_every_n_epochs)
# Check consistency def weights_update(self, losses):
check_positive_integer(value=k, strict=True)
# Initialize parameters
self.k = k
self.weights = {}
self.default_value_weights = 1.0
def aggregate(self, losses):
""" """
Weight the losses according to the self-adaptive algorithm. Update the weighting scheme based on the given losses.
:param dict(torch.Tensor) losses: The dictionary of losses. :param dict losses: The dictionary of losses.
:return: The aggregation of the losses. It should be a scalar Tensor. :return: The updated weights.
:rtype: torch.Tensor :rtype: dict
""" """
# If weights have not been initialized, set them to 1 # Define a dictionary to store the norms of the gradients
if not self.weights: losses_norm = {}
self.weights = {
condition: self.default_value_weights for condition in losses
}
# Update every k epochs # Compute the gradient norms for each loss component
if self.solver.trainer.current_epoch % self.k == 0: for condition, loss in losses.items():
loss.backward(retain_graph=True)
grads = torch.cat(
[p.grad.flatten() for p in self.solver.model.parameters()]
)
losses_norm[condition] = grads.norm()
# Define a dictionary to store the norms of the gradients # Update the weights
losses_norm = {} return {
condition: sum(losses_norm.values()) / losses_norm[condition]
# Compute the gradient norms for each loss component for condition in losses
for condition, loss in losses.items(): }
loss.backward(retain_graph=True)
grads = torch.cat(
[p.grad.flatten() for p in self.solver.model.parameters()]
)
losses_norm[condition] = grads.norm()
# Update the weights
self.weights = {
condition: sum(losses_norm.values()) / losses_norm[condition]
for condition in losses
}
return sum(
self.weights[condition] * loss for condition, loss in losses.items()
)

View File

@@ -1,6 +1,10 @@
"""Module for the Weighting Interface.""" """Module for the Weighting Interface."""
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import final
from ..utils import check_positive_integer, is_function
_AGGREGATE_METHODS = {"sum": sum, "mean": lambda x: sum(x) / len(x)}
class WeightingInterface(metaclass=ABCMeta): class WeightingInterface(metaclass=ABCMeta):
@@ -9,20 +13,93 @@ class WeightingInterface(metaclass=ABCMeta):
should inherit from this class. should inherit from this class.
""" """
def __init__(self): def __init__(self, update_every_n_epochs=1, aggregator="sum"):
""" """
Initialization of the :class:`WeightingInterface` class. Initialization of the :class:`WeightingInterface` class.
:param int update_every_n_epochs: The number of training epochs between
weight updates. If set to 1, the weights are updated at every epoch.
This parameter is ignored by static weighting schemes. Default is 1.
:param aggregator: The aggregation method. Either:
- 'sum' → torch.sum
- 'mean' → torch.mean
- callable → custom aggregation function
:type aggregator: str | Callable
""" """
# Check consistency
check_positive_integer(value=update_every_n_epochs, strict=True)
# Aggregation
if isinstance(aggregator, str):
if aggregator not in _AGGREGATE_METHODS:
raise ValueError(
f"Invalid aggregator '{aggregator}'. Must be one of "
f"{list(_AGGREGATE_METHODS.keys())}."
)
aggregator = _AGGREGATE_METHODS[aggregator]
elif not is_function(aggregator):
raise TypeError(
f"Aggregator must be either a string or a callable, "
f"got {type(aggregator).__name__}."
)
# Initialization
self._solver = None self._solver = None
self.update_every_n_epochs = update_every_n_epochs
self.aggregator_fn = aggregator
self._saved_weights = {}
@abstractmethod @abstractmethod
def aggregate(self, losses): def weights_update(self, losses):
""" """
Aggregate the losses. Update the weighting scheme based on the given losses.
This method must be implemented by subclasses. Its role is to update the
values of the weights. The updated weights will then be used by
:meth:`aggregate` to compute the final aggregated loss.
:param dict losses: The dictionary of losses. :param dict losses: The dictionary of losses.
:return: The updated weights.
:rtype: dict
""" """
@final
def aggregate(self, losses):
"""
Update the weights (if needed) and aggregate the given losses.
This method first checks whether the loss weights need to be updated
based on the current epoch and the ``update_every_n_epochs`` setting.
If an update is required, it calls :meth:`weights_update` to refresh the
weights. Afterwards, it aggregates the (weighted) losses into a single
scalar tensor using the configured aggregator function. This method must
not be overridden.
:param dict losses: The dictionary of losses.
:return: The aggregated loss tensor.
:rtype: torch.Tensor
"""
# Update weights
if self.solver.trainer.current_epoch % self.update_every_n_epochs == 0:
self._saved_weights = self.weights_update(losses)
# Aggregate. Using direct indexing instead of .get() ensures that a
# KeyError is raised if the expected condition is missing from the dict.
return self.aggregator_fn(
self._saved_weights[condition] * loss
for condition, loss in losses.items()
)
def last_saved_weights(self):
"""
Get the last saved weights.
:return: The last saved weights.
:rtype: dict
"""
return self._saved_weights
@property @property
def solver(self): def solver(self):
""" """

View File

@@ -240,3 +240,31 @@ def check_positive_integer(value, strict=True):
assert ( assert (
isinstance(value, int) and value >= 0 isinstance(value, int) and value >= 0
), f"Expected a non-negative integer, got {value}." ), f"Expected a non-negative integer, got {value}."
def in_range(value, range_vals, strict=True):
"""
Check if a value is within a specified range.
:param int value: The integer value to check.
:param list[int] range_vals: A list of two integers representing the range
limits. The first element specifies the lower bound, and the second
specifies the upper bound.
:param bool strict: If True, the value must be strictly positive.
Default is True.
:return: True if the value satisfies the range condition, False otherwise.
:rtype: bool
"""
# Validate inputs
check_consistency(value, (float, int))
check_consistency(range_vals, (float, int))
assert (
isinstance(range_vals, list) and len(range_vals) == 2
), "range_vals must be a list of two integers [lower, upper]"
lower, upper = range_vals
# Check the range
if strict:
return lower < value < upper
return lower <= value <= upper

View File

@@ -1,164 +0,0 @@
import pytest
import math
from pina.solver import PINN
from pina.loss import ScalarWeighting
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.callback import LinearWeightUpdate
# Define the problem
poisson_problem = Poisson()
poisson_problem.discretise_domain(50, "grid")
cond_name = list(poisson_problem.conditions.keys())[0]
# Define the model
model = FeedForward(
input_dimensions=len(poisson_problem.input_variables),
output_dimensions=len(poisson_problem.output_variables),
layers=[32, 32],
)
# Define the weighting schema
weights_dict = {key: 1 for key in poisson_problem.conditions.keys()}
weighting = ScalarWeighting(weights=weights_dict)
# Define the solver
solver = PINN(problem=poisson_problem, model=model, weighting=weighting)
# Value used for testing
epochs = 10
@pytest.mark.parametrize("initial_value", [1, 5.5])
@pytest.mark.parametrize("target_value", [10, 25.5])
def test_constructor(initial_value, target_value):
LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value=initial_value,
target_value=target_value,
)
# Target_epoch must be int
with pytest.raises(ValueError):
LinearWeightUpdate(
target_epoch=10.0,
condition_name=cond_name,
initial_value=0,
target_value=1,
)
# Condition_name must be str
with pytest.raises(ValueError):
LinearWeightUpdate(
target_epoch=epochs,
condition_name=100,
initial_value=0,
target_value=1,
)
# Initial_value must be float or int
with pytest.raises(ValueError):
LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value="0",
target_value=1,
)
# Target_value must be float or int
with pytest.raises(ValueError):
LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value=0,
target_value="1",
)
@pytest.mark.parametrize("initial_value, target_value", [(1, 10), (10, 1)])
def test_training(initial_value, target_value):
callback = LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value=initial_value,
target_value=target_value,
)
trainer = Trainer(
solver=solver,
callbacks=[callback],
accelerator="cpu",
max_epochs=epochs,
)
trainer.train()
# Check that the final weight value matches the target value
final_value = solver.weighting.weights[cond_name]
assert math.isclose(final_value, target_value)
# Target_epoch must be greater than 0
with pytest.raises(ValueError):
callback = LinearWeightUpdate(
target_epoch=0,
condition_name=cond_name,
initial_value=0,
target_value=1,
)
trainer = Trainer(
solver=solver,
callbacks=[callback],
accelerator="cpu",
max_epochs=5,
)
trainer.train()
# Target_epoch must be less than or equal to max_epochs
with pytest.raises(ValueError):
callback = LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value=0,
target_value=1,
)
trainer = Trainer(
solver=solver,
callbacks=[callback],
accelerator="cpu",
max_epochs=epochs - 1,
)
trainer.train()
# Condition_name must be a problem condition
with pytest.raises(ValueError):
callback = LinearWeightUpdate(
target_epoch=epochs,
condition_name="not_a_condition",
initial_value=0,
target_value=1,
)
trainer = Trainer(
solver=solver,
callbacks=[callback],
accelerator="cpu",
max_epochs=epochs,
)
trainer.train()
# Weighting schema must be ScalarWeighting
with pytest.raises(ValueError):
callback = LinearWeightUpdate(
target_epoch=epochs,
condition_name=cond_name,
initial_value=0,
target_value=1,
)
unweighted_solver = PINN(problem=poisson_problem, model=model)
trainer = Trainer(
solver=unweighted_solver,
callbacks=[callback],
accelerator="cpu",
max_epochs=epochs,
)
trainer.train()

View File

@@ -12,22 +12,42 @@ problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables)) model = FeedForward(len(problem.input_variables), len(problem.output_variables))
@pytest.mark.parametrize("update_every_n_epochs", [1, 10, 100, 1000])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_constructor(alpha): def test_constructor(update_every_n_epochs, alpha):
NeuralTangentKernelWeighting(alpha=alpha) NeuralTangentKernelWeighting(
update_every_n_epochs=update_every_n_epochs, alpha=alpha
)
# Should fail if alpha is not >= 0 # Should fail if alpha is not >= 0
with pytest.raises(ValueError): with pytest.raises(ValueError):
NeuralTangentKernelWeighting(alpha=-0.1) NeuralTangentKernelWeighting(
update_every_n_epochs=update_every_n_epochs, alpha=-0.1
)
# Should fail if alpha is not <= 1 # Should fail if alpha is not <= 1
with pytest.raises(ValueError): with pytest.raises(ValueError):
NeuralTangentKernelWeighting(alpha=1.1) NeuralTangentKernelWeighting(alpha=1.1)
# Should fail if update_every_n_epochs is not an integer
with pytest.raises(AssertionError):
NeuralTangentKernelWeighting(update_every_n_epochs=1.5)
# Should fail if update_every_n_epochs is not > 0
with pytest.raises(AssertionError):
NeuralTangentKernelWeighting(update_every_n_epochs=0)
# Should fail if update_every_n_epochs is not > 0
with pytest.raises(AssertionError):
NeuralTangentKernelWeighting(update_every_n_epochs=-3)
@pytest.mark.parametrize("update_every_n_epochs", [1, 3])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_train_aggregation(alpha): def test_train_aggregation(update_every_n_epochs, alpha):
weighting = NeuralTangentKernelWeighting(alpha=alpha) weighting = NeuralTangentKernelWeighting(
update_every_n_epochs=update_every_n_epochs, alpha=alpha
)
solver = PINN(problem=problem, model=model, weighting=weighting) solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train() trainer.train()

View File

@@ -29,20 +29,6 @@ def test_constructor(weights):
ScalarWeighting(weights=[1, 2, 3]) ScalarWeighting(weights=[1, 2, 3])
@pytest.mark.parametrize(
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
)
def test_aggregate(weights):
weighting = ScalarWeighting(weights=weights)
losses = dict(
zip(
condition_names,
[torch.randn(1) for _ in range(len(condition_names))],
)
)
weighting.aggregate(losses=losses)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))] "weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))]
) )

View File

@@ -12,26 +12,28 @@ problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables)) model = FeedForward(len(problem.input_variables), len(problem.output_variables))
@pytest.mark.parametrize("k", [10, 100, 1000]) @pytest.mark.parametrize("update_every_n_epochs", [10, 100, 1000])
def test_constructor(k): def test_constructor(update_every_n_epochs):
SelfAdaptiveWeighting(k=k) SelfAdaptiveWeighting(update_every_n_epochs=update_every_n_epochs)
# Should fail if k is not an integer # Should fail if update_every_n_epochs is not an integer
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
SelfAdaptiveWeighting(k=1.5) SelfAdaptiveWeighting(update_every_n_epochs=1.5)
# Should fail if k is not > 0 # Should fail if update_every_n_epochs is not > 0
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
SelfAdaptiveWeighting(k=0) SelfAdaptiveWeighting(update_every_n_epochs=0)
# Should fail if k is not > 0 # Should fail if update_every_n_epochs is not > 0
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
SelfAdaptiveWeighting(k=-3) SelfAdaptiveWeighting(update_every_n_epochs=-3)
@pytest.mark.parametrize("k", [2, 3]) @pytest.mark.parametrize("update_every_n_epochs", [1, 3])
def test_train_aggregation(k): def test_train_aggregation(update_every_n_epochs):
weighting = SelfAdaptiveWeighting(k=k) weighting = SelfAdaptiveWeighting(
update_every_n_epochs=update_every_n_epochs
)
solver = PINN(problem=problem, model=model, weighting=weighting) solver = PINN(problem=problem, model=model, weighting=weighting)
trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
trainer.train() trainer.train()