diff --git a/docs/source/_rst/callback/linear_weight_update_callback.rst b/docs/source/_rst/callback/linear_weight_update_callback.rst deleted file mode 100644 index fe45b56..0000000 --- a/docs/source/_rst/callback/linear_weight_update_callback.rst +++ /dev/null @@ -1,7 +0,0 @@ -Weighting callbacks -======================== - -.. currentmodule:: pina.callback.linear_weight_update_callback -.. autoclass:: LinearWeightUpdate - :members: - :show-inheritance: \ No newline at end of file diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py index dc1164e..e9a70ea 100644 --- a/pina/callback/__init__.py +++ b/pina/callback/__init__.py @@ -4,11 +4,9 @@ __all__ = [ "SwitchOptimizer", "MetricTracker", "PINAProgressBar", - "LinearWeightUpdate", "R3Refinement", ] from .optimizer_callback import SwitchOptimizer from .processing_callback import MetricTracker, PINAProgressBar -from .linear_weight_update_callback import LinearWeightUpdate from .refinement import R3Refinement diff --git a/pina/callback/linear_weight_update_callback.py b/pina/callback/linear_weight_update_callback.py deleted file mode 100644 index ae25ca1..0000000 --- a/pina/callback/linear_weight_update_callback.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Module for the LinearWeightUpdate callback.""" - -import warnings -from lightning.pytorch.callbacks import Callback -from ..utils import check_consistency -from ..loss import ScalarWeighting - - -class LinearWeightUpdate(Callback): - """ - Callback to linearly adjust the weight of a condition from an - initial value to a target value over a specified number of epochs. - """ - - def __init__( - self, target_epoch, condition_name, initial_value, target_value - ): - """ - Callback initialization. - - :param int target_epoch: The epoch at which the weight of the condition - should reach the target value. - :param str condition_name: The name of the condition whose weight - should be adjusted. - :param float initial_value: The initial value of the weight. - :param float target_value: The target value of the weight. - """ - super().__init__() - self.target_epoch = target_epoch - self.condition_name = condition_name - self.initial_value = initial_value - self.target_value = target_value - - # Check consistency - check_consistency(self.target_epoch, int, subclass=False) - check_consistency(self.condition_name, str, subclass=False) - check_consistency(self.initial_value, (float, int), subclass=False) - check_consistency(self.target_value, (float, int), subclass=False) - - def on_train_start(self, trainer, pl_module): - """ - Initialize the weight of the condition to the specified `initial_value`. - - :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. - :param SolverInterface pl_module: A - :class:`~pina.solver.solver.SolverInterface` instance. - """ - # Check that the target epoch is valid - if not 0 < self.target_epoch <= trainer.max_epochs: - raise ValueError( - "`target_epoch` must be greater than 0" - " and less than or equal to `max_epochs`." - ) - - # Check that the condition is a problem condition - if self.condition_name not in pl_module.problem.conditions: - raise ValueError( - f"`{self.condition_name}` must be a problem condition." - ) - - # Check that the initial value is not equal to the target value - if self.initial_value == self.target_value: - warnings.warn( - "`initial_value` is equal to `target_value`. " - "No effective adjustment will be performed.", - UserWarning, - ) - - # Check that the weighting schema is ScalarWeighting - if not isinstance(pl_module.weighting, ScalarWeighting): - raise ValueError("The weighting schema must be ScalarWeighting.") - - # Initialize the weight of the condition - pl_module.weighting.weights[self.condition_name] = self.initial_value - - def on_train_epoch_start(self, trainer, pl_module): - """ - Adjust at each epoch the weight of the condition. - - :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. - :param SolverInterface pl_module: A - :class:`~pina.solver.solver.SolverInterface` instance. - """ - if 0 < trainer.current_epoch <= self.target_epoch: - pl_module.weighting.weights[self.condition_name] += ( - self.target_value - self.initial_value - ) / (self.target_epoch - 1) diff --git a/pina/loss/ntk_weighting.py b/pina/loss/ntk_weighting.py index 6149f23..b888126 100644 --- a/pina/loss/ntk_weighting.py +++ b/pina/loss/ntk_weighting.py @@ -2,7 +2,7 @@ import torch from .weighting_interface import WeightingInterface -from ..utils import check_consistency +from ..utils import check_consistency, in_range class NeuralTangentKernelWeighting(WeightingInterface): @@ -20,32 +20,34 @@ class NeuralTangentKernelWeighting(WeightingInterface): """ - def __init__(self, alpha=0.5): + def __init__(self, update_every_n_epochs=1, alpha=0.5): """ Initialization of the :class:`NeuralTangentKernelWeighting` class. + :param int update_every_n_epochs: The number of training epochs between + weight updates. If set to 1, the weights are updated at every epoch. + Default is 1. :param float alpha: The alpha parameter. :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive). """ - super().__init__() + super().__init__(update_every_n_epochs=update_every_n_epochs) # Check consistency check_consistency(alpha, float) - if alpha < 0 or alpha > 1: - raise ValueError("alpha should be a value between 0 and 1") + if not in_range(alpha, [0, 1], strict=False): + raise ValueError("alpha must be in range (0, 1).") # Initialize parameters self.alpha = alpha self.weights = {} - self.default_value_weights = 1.0 - def aggregate(self, losses): + def weights_update(self, losses): """ - Weight the losses according to the Neural Tangent Kernel algorithm. + Update the weighting scheme based on the given losses. - :param dict(torch.Tensor) input: The dictionary of losses. - :return: The aggregation of the losses. It should be a scalar Tensor. - :rtype: torch.Tensor + :param dict losses: The dictionary of losses. + :return: The updated weights. + :rtype: dict """ # Define a dictionary to store the norms of the gradients losses_norm = {} @@ -60,14 +62,10 @@ class NeuralTangentKernelWeighting(WeightingInterface): # Update the weights self.weights = { - condition: self.alpha - * self.weights.get(condition, self.default_value_weights) + condition: self.alpha * self.weights.get(condition, 1) + (1 - self.alpha) * losses_norm[condition] / sum(losses_norm.values()) for condition in losses } - - return sum( - self.weights[condition] * loss for condition, loss in losses.items() - ) + return self.weights diff --git a/pina/loss/scalar_weighting.py b/pina/loss/scalar_weighting.py index c10b574..d770c89 100644 --- a/pina/loss/scalar_weighting.py +++ b/pina/loss/scalar_weighting.py @@ -4,22 +4,6 @@ from .weighting_interface import WeightingInterface from ..utils import check_consistency -class _NoWeighting(WeightingInterface): - """ - Weighting scheme that does not apply any weighting to the losses. - """ - - def aggregate(self, losses): - """ - Aggregate the losses. - - :param dict losses: The dictionary of losses. - :return: The aggregated losses. - :rtype: torch.Tensor - """ - return sum(losses.values()) - - class ScalarWeighting(WeightingInterface): """ Weighting scheme that assigns a scalar weight to each loss term. @@ -36,28 +20,42 @@ class ScalarWeighting(WeightingInterface): dictionary, the default value is used. :type weights: float | int | dict """ - super().__init__() + super().__init__(update_every_n_epochs=1, aggregator="sum") # Check consistency check_consistency([weights], (float, dict, int)) - # Weights initialization - if isinstance(weights, (float, int)): + # Initialization + if isinstance(weights, dict): + self.values = weights + self.default_value_weights = 1 + elif isinstance(weights, (float, int)): + self.values = {} self.default_value_weights = weights - self.weights = {} else: - self.default_value_weights = 1.0 - self.weights = weights + raise ValueError - def aggregate(self, losses): + def weights_update(self, losses): """ - Aggregate the losses. + Update the weighting scheme based on the given losses. :param dict losses: The dictionary of losses. - :return: The aggregated losses. - :rtype: torch.Tensor + :return: The updated weights. + :rtype: dict """ - return sum( - self.weights.get(condition, self.default_value_weights) * loss - for condition, loss in losses.items() - ) + return { + condition: self.values.get(condition, self.default_value_weights) + for condition in losses.keys() + } + + +class _NoWeighting(ScalarWeighting): + """ + Weighting scheme that does not apply any weighting to the losses. + """ + + def __init__(self): + """ + Initialization of the :class:`_NoWeighting` class. + """ + super().__init__(weights=1) diff --git a/pina/loss/self_adaptive_weighting.py b/pina/loss/self_adaptive_weighting.py index 8533078..62196c5 100644 --- a/pina/loss/self_adaptive_weighting.py +++ b/pina/loss/self_adaptive_weighting.py @@ -2,7 +2,6 @@ import torch from .weighting_interface import WeightingInterface -from ..utils import check_positive_integer class SelfAdaptiveWeighting(WeightingInterface): @@ -22,59 +21,37 @@ class SelfAdaptiveWeighting(WeightingInterface): """ - def __init__(self, k=100): + def __init__(self, update_every_n_epochs=1): """ Initialization of the :class:`SelfAdaptiveWeighting` class. - :param int k: The number of epochs after which the weights are updated. - Default is 100. - - :raises ValueError: If ``k`` is not a positive integer. + :param int update_every_n_epochs: The number of training epochs between + weight updates. If set to 1, the weights are updated at every epoch. + Default is 1. """ - super().__init__() + super().__init__(update_every_n_epochs=update_every_n_epochs) - # Check consistency - check_positive_integer(value=k, strict=True) - - # Initialize parameters - self.k = k - self.weights = {} - self.default_value_weights = 1.0 - - def aggregate(self, losses): + def weights_update(self, losses): """ - Weight the losses according to the self-adaptive algorithm. + Update the weighting scheme based on the given losses. - :param dict(torch.Tensor) losses: The dictionary of losses. - :return: The aggregation of the losses. It should be a scalar Tensor. - :rtype: torch.Tensor + :param dict losses: The dictionary of losses. + :return: The updated weights. + :rtype: dict """ - # If weights have not been initialized, set them to 1 - if not self.weights: - self.weights = { - condition: self.default_value_weights for condition in losses - } + # Define a dictionary to store the norms of the gradients + losses_norm = {} - # Update every k epochs - if self.solver.trainer.current_epoch % self.k == 0: + # Compute the gradient norms for each loss component + for condition, loss in losses.items(): + loss.backward(retain_graph=True) + grads = torch.cat( + [p.grad.flatten() for p in self.solver.model.parameters()] + ) + losses_norm[condition] = grads.norm() - # Define a dictionary to store the norms of the gradients - losses_norm = {} - - # Compute the gradient norms for each loss component - for condition, loss in losses.items(): - loss.backward(retain_graph=True) - grads = torch.cat( - [p.grad.flatten() for p in self.solver.model.parameters()] - ) - losses_norm[condition] = grads.norm() - - # Update the weights - self.weights = { - condition: sum(losses_norm.values()) / losses_norm[condition] - for condition in losses - } - - return sum( - self.weights[condition] * loss for condition, loss in losses.items() - ) + # Update the weights + return { + condition: sum(losses_norm.values()) / losses_norm[condition] + for condition in losses + } diff --git a/pina/loss/weighting_interface.py b/pina/loss/weighting_interface.py index 567d493..bc34c31 100644 --- a/pina/loss/weighting_interface.py +++ b/pina/loss/weighting_interface.py @@ -1,6 +1,10 @@ """Module for the Weighting Interface.""" from abc import ABCMeta, abstractmethod +from typing import final +from ..utils import check_positive_integer, is_function + +_AGGREGATE_METHODS = {"sum": sum, "mean": lambda x: sum(x) / len(x)} class WeightingInterface(metaclass=ABCMeta): @@ -9,20 +13,93 @@ class WeightingInterface(metaclass=ABCMeta): should inherit from this class. """ - def __init__(self): + def __init__(self, update_every_n_epochs=1, aggregator="sum"): """ Initialization of the :class:`WeightingInterface` class. + + :param int update_every_n_epochs: The number of training epochs between + weight updates. If set to 1, the weights are updated at every epoch. + This parameter is ignored by static weighting schemes. Default is 1. + :param aggregator: The aggregation method. Either: + - 'sum' → torch.sum + - 'mean' → torch.mean + - callable → custom aggregation function + :type aggregator: str | Callable """ + # Check consistency + check_positive_integer(value=update_every_n_epochs, strict=True) + + # Aggregation + if isinstance(aggregator, str): + if aggregator not in _AGGREGATE_METHODS: + raise ValueError( + f"Invalid aggregator '{aggregator}'. Must be one of " + f"{list(_AGGREGATE_METHODS.keys())}." + ) + aggregator = _AGGREGATE_METHODS[aggregator] + + elif not is_function(aggregator): + raise TypeError( + f"Aggregator must be either a string or a callable, " + f"got {type(aggregator).__name__}." + ) + + # Initialization self._solver = None + self.update_every_n_epochs = update_every_n_epochs + self.aggregator_fn = aggregator + self._saved_weights = {} @abstractmethod - def aggregate(self, losses): + def weights_update(self, losses): """ - Aggregate the losses. + Update the weighting scheme based on the given losses. + + This method must be implemented by subclasses. Its role is to update the + values of the weights. The updated weights will then be used by + :meth:`aggregate` to compute the final aggregated loss. :param dict losses: The dictionary of losses. + :return: The updated weights. + :rtype: dict """ + @final + def aggregate(self, losses): + """ + Update the weights (if needed) and aggregate the given losses. + + This method first checks whether the loss weights need to be updated + based on the current epoch and the ``update_every_n_epochs`` setting. + If an update is required, it calls :meth:`weights_update` to refresh the + weights. Afterwards, it aggregates the (weighted) losses into a single + scalar tensor using the configured aggregator function. This method must + not be overridden. + + :param dict losses: The dictionary of losses. + :return: The aggregated loss tensor. + :rtype: torch.Tensor + """ + # Update weights + if self.solver.trainer.current_epoch % self.update_every_n_epochs == 0: + self._saved_weights = self.weights_update(losses) + + # Aggregate. Using direct indexing instead of .get() ensures that a + # KeyError is raised if the expected condition is missing from the dict. + return self.aggregator_fn( + self._saved_weights[condition] * loss + for condition, loss in losses.items() + ) + + def last_saved_weights(self): + """ + Get the last saved weights. + + :return: The last saved weights. + :rtype: dict + """ + return self._saved_weights + @property def solver(self): """ diff --git a/pina/utils.py b/pina/utils.py index ddbd2e8..2aafba1 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -240,3 +240,31 @@ def check_positive_integer(value, strict=True): assert ( isinstance(value, int) and value >= 0 ), f"Expected a non-negative integer, got {value}." + + +def in_range(value, range_vals, strict=True): + """ + Check if a value is within a specified range. + + :param int value: The integer value to check. + :param list[int] range_vals: A list of two integers representing the range + limits. The first element specifies the lower bound, and the second + specifies the upper bound. + :param bool strict: If True, the value must be strictly positive. + Default is True. + :return: True if the value satisfies the range condition, False otherwise. + :rtype: bool + """ + # Validate inputs + check_consistency(value, (float, int)) + check_consistency(range_vals, (float, int)) + assert ( + isinstance(range_vals, list) and len(range_vals) == 2 + ), "range_vals must be a list of two integers [lower, upper]" + lower, upper = range_vals + + # Check the range + if strict: + return lower < value < upper + + return lower <= value <= upper diff --git a/tests/test_callback/test_linear_weight_update_callback.py b/tests/test_callback/test_linear_weight_update_callback.py deleted file mode 100644 index c1f4cf3..0000000 --- a/tests/test_callback/test_linear_weight_update_callback.py +++ /dev/null @@ -1,164 +0,0 @@ -import pytest -import math -from pina.solver import PINN -from pina.loss import ScalarWeighting -from pina.trainer import Trainer -from pina.model import FeedForward -from pina.problem.zoo import Poisson2DSquareProblem as Poisson -from pina.callback import LinearWeightUpdate - - -# Define the problem -poisson_problem = Poisson() -poisson_problem.discretise_domain(50, "grid") -cond_name = list(poisson_problem.conditions.keys())[0] - -# Define the model -model = FeedForward( - input_dimensions=len(poisson_problem.input_variables), - output_dimensions=len(poisson_problem.output_variables), - layers=[32, 32], -) - -# Define the weighting schema -weights_dict = {key: 1 for key in poisson_problem.conditions.keys()} -weighting = ScalarWeighting(weights=weights_dict) - -# Define the solver -solver = PINN(problem=poisson_problem, model=model, weighting=weighting) - -# Value used for testing -epochs = 10 - - -@pytest.mark.parametrize("initial_value", [1, 5.5]) -@pytest.mark.parametrize("target_value", [10, 25.5]) -def test_constructor(initial_value, target_value): - LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value=initial_value, - target_value=target_value, - ) - - # Target_epoch must be int - with pytest.raises(ValueError): - LinearWeightUpdate( - target_epoch=10.0, - condition_name=cond_name, - initial_value=0, - target_value=1, - ) - - # Condition_name must be str - with pytest.raises(ValueError): - LinearWeightUpdate( - target_epoch=epochs, - condition_name=100, - initial_value=0, - target_value=1, - ) - - # Initial_value must be float or int - with pytest.raises(ValueError): - LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value="0", - target_value=1, - ) - - # Target_value must be float or int - with pytest.raises(ValueError): - LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value=0, - target_value="1", - ) - - -@pytest.mark.parametrize("initial_value, target_value", [(1, 10), (10, 1)]) -def test_training(initial_value, target_value): - callback = LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value=initial_value, - target_value=target_value, - ) - trainer = Trainer( - solver=solver, - callbacks=[callback], - accelerator="cpu", - max_epochs=epochs, - ) - trainer.train() - - # Check that the final weight value matches the target value - final_value = solver.weighting.weights[cond_name] - assert math.isclose(final_value, target_value) - - # Target_epoch must be greater than 0 - with pytest.raises(ValueError): - callback = LinearWeightUpdate( - target_epoch=0, - condition_name=cond_name, - initial_value=0, - target_value=1, - ) - trainer = Trainer( - solver=solver, - callbacks=[callback], - accelerator="cpu", - max_epochs=5, - ) - trainer.train() - - # Target_epoch must be less than or equal to max_epochs - with pytest.raises(ValueError): - callback = LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value=0, - target_value=1, - ) - trainer = Trainer( - solver=solver, - callbacks=[callback], - accelerator="cpu", - max_epochs=epochs - 1, - ) - trainer.train() - - # Condition_name must be a problem condition - with pytest.raises(ValueError): - callback = LinearWeightUpdate( - target_epoch=epochs, - condition_name="not_a_condition", - initial_value=0, - target_value=1, - ) - trainer = Trainer( - solver=solver, - callbacks=[callback], - accelerator="cpu", - max_epochs=epochs, - ) - trainer.train() - - # Weighting schema must be ScalarWeighting - with pytest.raises(ValueError): - callback = LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value=0, - target_value=1, - ) - unweighted_solver = PINN(problem=poisson_problem, model=model) - trainer = Trainer( - solver=unweighted_solver, - callbacks=[callback], - accelerator="cpu", - max_epochs=epochs, - ) - trainer.train() diff --git a/tests/test_weighting/test_ntk_weighting.py b/tests/test_weighting/test_ntk_weighting.py index 236c498..49442b9 100644 --- a/tests/test_weighting/test_ntk_weighting.py +++ b/tests/test_weighting/test_ntk_weighting.py @@ -12,22 +12,42 @@ problem.discretise_domain(10) model = FeedForward(len(problem.input_variables), len(problem.output_variables)) +@pytest.mark.parametrize("update_every_n_epochs", [1, 10, 100, 1000]) @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) -def test_constructor(alpha): - NeuralTangentKernelWeighting(alpha=alpha) +def test_constructor(update_every_n_epochs, alpha): + NeuralTangentKernelWeighting( + update_every_n_epochs=update_every_n_epochs, alpha=alpha + ) # Should fail if alpha is not >= 0 with pytest.raises(ValueError): - NeuralTangentKernelWeighting(alpha=-0.1) + NeuralTangentKernelWeighting( + update_every_n_epochs=update_every_n_epochs, alpha=-0.1 + ) # Should fail if alpha is not <= 1 with pytest.raises(ValueError): NeuralTangentKernelWeighting(alpha=1.1) + # Should fail if update_every_n_epochs is not an integer + with pytest.raises(AssertionError): + NeuralTangentKernelWeighting(update_every_n_epochs=1.5) + # Should fail if update_every_n_epochs is not > 0 + with pytest.raises(AssertionError): + NeuralTangentKernelWeighting(update_every_n_epochs=0) + + # Should fail if update_every_n_epochs is not > 0 + with pytest.raises(AssertionError): + NeuralTangentKernelWeighting(update_every_n_epochs=-3) + + +@pytest.mark.parametrize("update_every_n_epochs", [1, 3]) @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) -def test_train_aggregation(alpha): - weighting = NeuralTangentKernelWeighting(alpha=alpha) +def test_train_aggregation(update_every_n_epochs, alpha): + weighting = NeuralTangentKernelWeighting( + update_every_n_epochs=update_every_n_epochs, alpha=alpha + ) solver = PINN(problem=problem, model=model, weighting=weighting) trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") trainer.train() diff --git a/tests/test_weighting/test_scalar_weighting.py b/tests/test_weighting/test_scalar_weighting.py index 54b3293..bbf71af 100644 --- a/tests/test_weighting/test_scalar_weighting.py +++ b/tests/test_weighting/test_scalar_weighting.py @@ -29,20 +29,6 @@ def test_constructor(weights): ScalarWeighting(weights=[1, 2, 3]) -@pytest.mark.parametrize( - "weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))] -) -def test_aggregate(weights): - weighting = ScalarWeighting(weights=weights) - losses = dict( - zip( - condition_names, - [torch.randn(1) for _ in range(len(condition_names))], - ) - ) - weighting.aggregate(losses=losses) - - @pytest.mark.parametrize( "weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))] ) diff --git a/tests/test_weighting/test_self_adaptive_weighting.py b/tests/test_weighting/test_self_adaptive_weighting.py index b82f545..066e885 100644 --- a/tests/test_weighting/test_self_adaptive_weighting.py +++ b/tests/test_weighting/test_self_adaptive_weighting.py @@ -12,26 +12,28 @@ problem.discretise_domain(10) model = FeedForward(len(problem.input_variables), len(problem.output_variables)) -@pytest.mark.parametrize("k", [10, 100, 1000]) -def test_constructor(k): - SelfAdaptiveWeighting(k=k) +@pytest.mark.parametrize("update_every_n_epochs", [10, 100, 1000]) +def test_constructor(update_every_n_epochs): + SelfAdaptiveWeighting(update_every_n_epochs=update_every_n_epochs) - # Should fail if k is not an integer + # Should fail if update_every_n_epochs is not an integer with pytest.raises(AssertionError): - SelfAdaptiveWeighting(k=1.5) + SelfAdaptiveWeighting(update_every_n_epochs=1.5) - # Should fail if k is not > 0 + # Should fail if update_every_n_epochs is not > 0 with pytest.raises(AssertionError): - SelfAdaptiveWeighting(k=0) + SelfAdaptiveWeighting(update_every_n_epochs=0) - # Should fail if k is not > 0 + # Should fail if update_every_n_epochs is not > 0 with pytest.raises(AssertionError): - SelfAdaptiveWeighting(k=-3) + SelfAdaptiveWeighting(update_every_n_epochs=-3) -@pytest.mark.parametrize("k", [2, 3]) -def test_train_aggregation(k): - weighting = SelfAdaptiveWeighting(k=k) +@pytest.mark.parametrize("update_every_n_epochs", [1, 3]) +def test_train_aggregation(update_every_n_epochs): + weighting = SelfAdaptiveWeighting( + update_every_n_epochs=update_every_n_epochs + ) solver = PINN(problem=problem, model=model, weighting=weighting) trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") trainer.train()