diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index 5d3c3e5..e298aaf 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -15,7 +15,7 @@ The pipeline to solve differential equations with PINA follows just five steps:
     2. Generate data using built in `Geometries`_, or load high level simulation results as :doc:`LabelTensor `
     3. Choose or build one or more `Models`_ to solve the problem
     4. Choose a solver across PINA available `Solvers`_, or build one using the :doc:`SolverInterface `
-    5. Train the model with the PINA :doc:`Trainer `, enhance the train with `Callbacks_`
+    5. Train the model with the PINA :doc:`Trainer `, enhance the training with `Callbacks`_
 
 PINA Features
 --------------
@@ -155,9 +155,9 @@ Callbacks
 .. toctree::
     :titlesonly:
 
-    Metric tracking
-    Optimizer callbacks
-    Adaptive Refinments
+    Processing Callbacks
+    Optimizer Callbacks
+    Adaptive Refinement Callback
 
 Metrics and Losses
 --------------------
diff --git a/docs/source/_rst/callbacks/processing_callbacks.rst b/docs/source/_rst/callbacks/processing_callbacks.rst
index e024a49..bd3bbc8 100644
--- a/docs/source/_rst/callbacks/processing_callbacks.rst
+++ b/docs/source/_rst/callbacks/processing_callbacks.rst
@@ -3,5 +3,9 @@ Processing callbacks
 .. currentmodule:: pina.callbacks.processing_callbacks
 
 .. autoclass:: MetricTracker
+   :members:
+   :show-inheritance:
+
+.. autoclass:: PINAProgressBar
    :members:
    :show-inheritance:
\ No newline at end of file
diff --git a/pina/callbacks/__init__.py b/pina/callbacks/__init__.py
index 9698136..4ba0271 100644
--- a/pina/callbacks/__init__.py
+++ b/pina/callbacks/__init__.py
@@ -1,5 +1,10 @@
-__all__ = ["SwitchOptimizer", "R3Refinement", "MetricTracker"]
+__all__ = [
+    "SwitchOptimizer",
+    "R3Refinement",
+    "MetricTracker",
+    "PINAProgressBar"
+]
 
 from .optimizer_callbacks import SwitchOptimizer
 from .adaptive_refinment_callbacks import R3Refinement
-from .processing_callbacks import MetricTracker
+from .processing_callbacks import MetricTracker, PINAProgressBar
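With the `pina/callbacks/__init__.py` change above, the new progress bar becomes part of the package's public API. A minimal import check (a sketch; it only assumes an installation that includes this patch):

```python
# Both callbacks now resolve from the package root as well as from the
# implementation module.
from pina.callbacks import MetricTracker, PINAProgressBar
from pina.callbacks.processing_callbacks import PINAProgressBar as SameClass

assert PINAProgressBar is SameClass
```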
diff --git a/pina/callbacks/processing_callbacks.py b/pina/callbacks/processing_callbacks.py
index 3b86936..c6175b0 100644
--- a/pina/callbacks/processing_callbacks.py
+++ b/pina/callbacks/processing_callbacks.py
@@ -1,11 +1,14 @@
 """PINA Callbacks Implementations"""
 
-from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.module import LightningModule
 from pytorch_lightning.trainer.trainer import Trainer
 import torch
 import copy
+from pytorch_lightning.callbacks import Callback, TQDMProgressBar
+from pytorch_lightning.callbacks.progress.progress_bar import get_standard_metrics
+from pytorch_lightning.utilities import rank_zero_warn
+from pina.utils import check_consistency
 
 
 class MetricTracker(Callback):
 
@@ -13,9 +16,11 @@ class MetricTracker(Callback):
     """
     PINA Implementation of a Lightning Callback for Metric Tracking.
 
-    This class provides functionality to track relevant metrics during the training process.
+    This class provides functionality to track relevant metrics during
+    the training process.
 
-    :ivar _collection: A list to store collected metrics after each training epoch.
+    :ivar _collection: A list to store collected metrics after each
+        training epoch.
 
     :param trainer: The trainer object managing the training process.
     :type trainer: pytorch_lightning.Trainer
@@ -28,20 +33,16 @@ class MetricTracker(Callback):
         >>> # ... Perform training ...
         >>> metrics = tracker.metrics
         """
+        super().__init__()
         self._collection = []
 
-    def on_train_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_end(self, trainer, pl_module):
         """
-        Collect and track metrics at the start of each training epoch. At epoch
-        zero the metric is not saved. At epoch ``k`` the metric which is tracked
-        is the one of epoch ``k-1``.
+        Collect and track metrics at the end of each training epoch.
 
         :param trainer: The trainer object managing the training process.
         :type trainer: pytorch_lightning.Trainer
         :param pl_module: Placeholder argument.
-
-        :return: None
-        :rtype: None
         """
         super().on_train_epoch_end(trainer, pl_module)
         if trainer.current_epoch > 0:
@@ -49,23 +50,6 @@ class MetricTracker(Callback):
             copy.deepcopy(trainer.logged_metrics)
         )  # track them
 
-    def on_train_end(self, trainer, pl_module):
-        """
-        Collect and track metrics at the end of training.
-
-        :param trainer: The trainer object managing the training process.
-        :type trainer: pytorch_lightning.Trainer
-        :param pl_module: Placeholder argument.
-
-        :return: None
-        :rtype: None
-        """
-        super().on_train_end(trainer, pl_module)
-        if trainer.current_epoch > 0:
-            self._collection.append(
-                copy.deepcopy(trainer.logged_metrics)
-            )  # track them
-
     @property
     def metrics(self):
         """
@@ -80,3 +64,91 @@
             for k in common_keys
         }
         return v
+
+
+class PINAProgressBar(TQDMProgressBar):
+
+    BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
+
+    def __init__(self, metrics='mean', **kwargs):
+        """
+        PINA Implementation of a Lightning Callback for enriching the progress
+        bar.
+
+        This class provides functionality to display only relevant metrics
+        during the training process.
+
+        :param metrics: Logged metrics to display during training. It should
+            be a subset of the condition keys defined in
+            :obj:`pina.condition.Condition`.
+        :type metrics: str | list(str) | tuple(str)
+
+        :Keyword Arguments:
+            The additional keyword arguments specify the progress bar
+            and can be chosen from the `pytorch-lightning
+            TQDMProgressBar API `_
+
+        Example:
+            >>> pbar = PINAProgressBar(['mean'])
+            >>> trainer = Trainer(solver, callbacks=[pbar])
+            >>> # ... Perform training ...
+        """
+        super().__init__(**kwargs)
+        # check consistency
+        if not isinstance(metrics, (list, tuple)):
+            metrics = [metrics]
+        check_consistency(metrics, str)
+        self._sorted_metrics = metrics
+
+    def get_metrics(self, trainer, pl_module):
+        r"""Combines progress bar metrics collected from the trainer with
+        standard metrics from get_standard_metrics.
+        Implement this to override the items displayed in the progress bar.
+        The progress bar metrics are filtered and ordered according to ``metrics``.
+
+        Here is an example of how to override the defaults:
+
+        .. code-block:: python
+
+            def get_metrics(self, trainer, model):
+                # don't show the version number
+                items = super().get_metrics(trainer, model)
+                items.pop("v_num", None)
+                return items
+
+        :return: Dictionary with the items to be displayed in the progress bar.
+        :rtype: dict
+        """
+        standard_metrics = get_standard_metrics(trainer)
+        pbar_metrics = trainer.progress_bar_metrics
+        if pbar_metrics:
+            pbar_metrics = {
+                key: pbar_metrics[key] for key in self._sorted_metrics
+            }
+        duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
+        if duplicates:
+            rank_zero_warn(
+                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
+                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
+                " If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
+            )
+
+        return {**standard_metrics, **pbar_metrics}
+
+    def on_fit_start(self, trainer, pl_module):
+        """
+        Check that the metrics defined in the initialization are available,
+        i.e. are correctly logged.
+
+        :param trainer: The trainer object managing the training process.
+        :type trainer: pytorch_lightning.Trainer
+        :param pl_module: Placeholder argument.
+        """
+        # check that every requested metric matches a problem condition
+        for key in self._sorted_metrics:
+            if key not in trainer.solver.problem.conditions.keys() and key != 'mean':
+                raise KeyError(f"Metric '{key}' is neither a problem condition nor 'mean'")
+        # append the '_loss' suffix used by the solver when logging
+        self._sorted_metrics = [
+            metric + '_loss' for metric in self._sorted_metrics]
+        return super().on_fit_start(trainer, pl_module)
\ No newline at end of file
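Taken together, the two callbacks above cover display and storage of the logged losses. A hedged usage sketch (`solver` stands for any already-built PINA solver, such as the PINN constructed in the test file below; the metric keys follow the rule enforced by `on_fit_start`):

```python
from pina.trainer import Trainer
from pina.callbacks import MetricTracker, PINAProgressBar

# Show only the aggregate 'mean' loss and the 'D' condition loss on the bar;
# keys are condition names (or 'mean'), the '_loss' suffix is added at fit start.
pbar = PINAProgressBar(metrics=['mean', 'D'])
tracker = MetricTracker()

trainer = Trainer(solver=solver,  # assumed: a PINA solver built beforehand
                  callbacks=[pbar, tracker],
                  accelerator='cpu',
                  max_epochs=10)
trainer.train()
print(tracker.metrics)  # dict of metric histories stacked over the epochs
```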
diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py
index 53d4d3a..0f82056 100644
--- a/pina/solvers/pinns/basepinn.py
+++ b/pina/solvers/pinns/basepinn.py
@@ -126,8 +126,9 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         # clamp unknown parameters in InverseProblem (if needed)
         self._clamp_params()
 
-        # total loss (must be a torch.Tensor)
+        # total loss (must be a torch.Tensor), and logs
         total_loss = sum(condition_losses)
+        self.save_logs_and_release()
         return total_loss.as_subclass(torch.Tensor)
 
     def loss_data(self, input_tensor, output_tensor):
@@ -205,8 +206,8 @@
         )
         self.__logged_res_losses.append(loss_value)
 
-    def on_train_epoch_end(self):
+    def save_logs_and_release(self):
         """
-        At the end of each epoch we free the stored losses. This function
-        should not be override if not intentionally.
+        Log the stored residual losses and then release them. This function
+        should not be overridden unless intentionally.
         """
@@ -218,7 +219,6 @@
         )  # free the logged losses
         self.__logged_res_losses = []
-        return super().on_train_epoch_end()
 
 
     def _clamp_inverse_problem_params(self):
         """
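The point of this move: the mean and per-condition losses are now logged from within the training step, so they already sit in `trainer.logged_metrics` when epoch-end hooks run, which is what lets `MetricTracker.on_train_epoch_end` above collect them. A sketch of a callback relying on that ordering (illustrative only, not part of the patch; `'mean_loss'` is the key the solver logs):

```python
from pytorch_lightning.callbacks import Callback


class PrintMeanLoss(Callback):
    """Illustrative callback: read back the loss logged during the step."""

    def on_train_epoch_end(self, trainer, pl_module):
        # 'mean_loss' is logged by PINNInterface.save_logs_and_release()
        # inside the training step, so it is already visible here.
        mean = trainer.logged_metrics.get("mean_loss")
        if mean is not None:
            print(f"epoch {trainer.current_epoch}: mean_loss = {float(mean):.4e}")
```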
diff --git a/tests/test_callbacks/test_progress_bar.py b/tests/test_callbacks/test_progress_bar.py
new file mode 100644
index 0000000..990b471
--- /dev/null
+++ b/tests/test_callbacks/test_progress_bar.py
@@ -0,0 +1,78 @@
+import torch
+import pytest
+
+from pina.problem import SpatialProblem
+from pina.operators import laplacian
+from pina.geometry import CartesianDomain
+from pina import Condition, LabelTensor
+from pina.solvers import PINN
+from pina.trainer import Trainer
+from pina.model import FeedForward
+from pina.equation.equation import Equation
+from pina.equation.equation_factory import FixedValue
+from pina.callbacks.processing_callbacks import PINAProgressBar
+
+
+def laplace_equation(input_, output_):
+    force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
+                  torch.sin(input_.extract(['y']) * torch.pi))
+    delta_u = laplacian(output_.extract(['u']), input_)
+    return delta_u - force_term
+
+
+my_laplace = Equation(laplace_equation)
+in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
+out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
+
+
+class Poisson(SpatialProblem):
+    output_variables = ['u']
+    spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
+
+    conditions = {
+        'gamma1': Condition(
+            location=CartesianDomain({'x': [0, 1], 'y': 1}),
+            equation=FixedValue(0.0)),
+        'gamma2': Condition(
+            location=CartesianDomain({'x': [0, 1], 'y': 0}),
+            equation=FixedValue(0.0)),
+        'gamma3': Condition(
+            location=CartesianDomain({'x': 1, 'y': [0, 1]}),
+            equation=FixedValue(0.0)),
+        'gamma4': Condition(
+            location=CartesianDomain({'x': 0, 'y': [0, 1]}),
+            equation=FixedValue(0.0)),
+        'D': Condition(
+            input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
+            equation=my_laplace),
+        'data': Condition(
+            input_points=in_,
+            output_points=out_)
+    }
+
+
+# make the problem
+poisson_problem = Poisson()
+boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+n = 10
+poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
+model = FeedForward(len(poisson_problem.input_variables),
+                    len(poisson_problem.output_variables))
+
+# make the solver
+solver = PINN(problem=poisson_problem, model=model)
+
+
+def test_progress_bar_constructor():
+    PINAProgressBar(['mean'])
+
+def test_progress_bar_routine():
+    # make the trainer
+    trainer = Trainer(solver=solver,
+                      callbacks=[
+                          PINAProgressBar(['mean', 'D'])
+                      ],
+                      accelerator='cpu',
+                      max_epochs=5)
+    trainer.train()
+    # TODO there should be a check that the correct metrics are displayed
\ No newline at end of file
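One possible way to close the TODO above (a sketch, not part of the patch): query the bar through `get_metrics` after training and check which keys survive the filtering. This assumes the solver logs the condition losses with `prog_bar=True`, so that they appear in `trainer.progress_bar_metrics`:

```python
def test_progress_bar_displayed_metrics():
    # sketch for the TODO: only the requested losses (plus Lightning's
    # standard items such as 'v_num') should be displayed
    pbar = PINAProgressBar(['mean', 'D'])
    trainer = Trainer(solver=solver,
                      callbacks=[pbar],
                      accelerator='cpu',
                      max_epochs=5)
    trainer.train()
    displayed = pbar.get_metrics(trainer, solver)
    assert 'mean_loss' in displayed
    assert 'D_loss' in displayed
    assert 'gamma1_loss' not in displayed  # filtered out by the callback
```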