Update callbacks and tests (#482)

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Dario Coscia
2025-03-13 16:19:38 +01:00
committed by FilippoOlivo
parent 18d178ab3a
commit 9dab6380f8
8 changed files with 264 additions and 229 deletions


@@ -1,7 +1,7 @@
"""PINA Callbacks Implementations"""
import torch
import copy
import torch
from lightning.pytorch.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import (
@@ -11,22 +11,37 @@ from pina.utils import check_consistency
 class MetricTracker(Callback):
     """
     Lightning Callback for Metric Tracking.
     """

     def __init__(self, metrics_to_track=None):
         """
         Lightning Callback for Metric Tracking.

-        Tracks specified metrics during training.
+        Tracks specific metrics during the training process.
+
+        :ivar _collection: A list to store collected metrics after each epoch.

-        :param metrics_to_track: List of metrics to track. Defaults to train/val loss.
-        :type metrics_to_track: list, optional
+        :param metrics_to_track: List of metrics to track.
+            Defaults to train loss.
+        :type metrics_to_track: list[str], optional
         """
         super().__init__()
         self._collection = []
-        # Default to tracking 'train_loss' and 'val_loss' if not specified
-        self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"]
+        # Default to tracking 'train_loss' if not specified
+        self.metrics_to_track = metrics_to_track
+
+    def setup(self, trainer, pl_module, stage):
+        """
+        Called when fit, validate, test, predict, or tune begins.
+
+        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
+        :param SolverInterface pl_module: A
+            :class:`~pina.solver.solver.SolverInterface` instance.
+        :param str stage: Either 'fit', 'validate', 'test' or 'predict'.
+        """
+        if self.metrics_to_track is None and trainer.batch_size is None:
+            self.metrics_to_track = ["train_loss"]
+        elif self.metrics_to_track is None:
+            self.metrics_to_track = ["train_loss_epoch"]
+        return super().setup(trainer, pl_module, stage)

     def on_train_epoch_end(self, trainer, pl_module):
         """
@@ -71,26 +86,28 @@ class MetricTracker(Callback):
 class PINAProgressBar(TQDMProgressBar):
     """
     PINA Implementation of a Lightning Callback for enriching the progress bar.
     """

-    BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
+    BAR_FORMAT = (
+        "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, "
+        "{rate_noinv_fmt}{postfix}]"
+    )

     def __init__(self, metrics="val", **kwargs):
         """
         PINA Implementation of a Lightning Callback for enriching the progress
         bar.

-        This class enables the display of only relevant metrics during training.
+        This class provides functionality to display only relevant metrics
+        during the training process.

-        :param metrics: Logged metrics to display during the training. It should
-            be a subset of the conditions keys defined in
+        :param metrics: Logged metrics to be shown during the training.
+            Must be a subset of the conditions keys defined in
             :obj:`pina.condition.Condition`.
         :type metrics: str | list(str) | tuple(str)

         :Keyword Arguments:
-            The additional keyword arguments specify the progress bar
-            and can be chosen from the `pytorch-lightning
-            TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_
+            The additional keyword arguments specify the progress bar and can be
+            chosen from the `pytorch-lightning TQDMProgressBar API
+            <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_

         Example:
             >>> pbar = PINAProgressBar(['mean'])
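
A plausible continuation of the doctest (hedged: `solver` is again an assumed, pre-built PINA solver), since the bar is attached to the trainer like any other Lightning callback:

    >>> trainer = Trainer(solver=solver, callbacks=[pbar])
    >>> trainer.train()
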
@@ -105,9 +122,9 @@ class PINAProgressBar(TQDMProgressBar):
         self._sorted_metrics = metrics

     def get_metrics(self, trainer, pl_module):
-        r"""Combines progress bar metrics collected from the trainer with
+        r"""Combine progress bar metrics collected from the trainer with
         standard metrics from get_standard_metrics.

-        Implement this to override the items displayed in the progress bar.
+        Override this method to customize the items shown in the progress bar.

         The progress bar metrics are sorted according to ``metrics``.

         Here is an example of how to override the defaults:
@@ -122,20 +139,20 @@ class PINAProgressBar(TQDMProgressBar):
         :return: Dictionary with the items to be displayed in the progress bar.
         :rtype: tuple(dict)
         """
         standard_metrics = get_standard_metrics(trainer)
         pbar_metrics = trainer.progress_bar_metrics
         if pbar_metrics:
             pbar_metrics = {
-                key: pbar_metrics[key] for key in self._sorted_metrics
+                key: pbar_metrics[key]
+                for key in pbar_metrics
+                if key in self._sorted_metrics
             }
         return {**standard_metrics, **pbar_metrics}

-    def on_fit_start(self, trainer, pl_module):
+    def setup(self, trainer, pl_module, stage):
         """
-        Check that the metrics defined in the initialization are available,
-        i.e. are correctly logged.
+        Check that the initialized metrics are available and correctly logged.

         :param trainer: The trainer object managing the training process.
         :type trainer: pytorch_lightning.Trainer
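
As an aside, the kind of override the get_metrics docstring alludes to follows the standard Lightning recipe; a hedged sketch (the subclass name is illustrative, and `v_num` is one of Lightning's standard progress-bar entries):

    class QuietProgressBar(PINAProgressBar):
        def get_metrics(self, trainer, pl_module):
            # start from the filtered metrics assembled by the parent class
            items = super().get_metrics(trainer, pl_module)
            items.pop("v_num", None)  # hide the experiment version number
            return items
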
@@ -150,7 +167,11 @@ class PINAProgressBar(TQDMProgressBar):
             ):
                 raise KeyError(f"Key '{key}' is not present in the dictionary")
         # add the loss pedix
+        if trainer.batch_size is not None:
+            pedix = "_loss_epoch"
+        else:
+            pedix = "_loss"
         self._sorted_metrics = [
-            metric + "_loss" for metric in self._sorted_metrics
+            metric + pedix for metric in self._sorted_metrics
         ]
-        return super().on_fit_start(trainer, pl_module)
+        return super().setup(trainer, pl_module, stage)
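
To make the effect of the new pedix logic concrete, a hedged illustration (values hypothetical): a condition named 'mean' passed at construction is resolved to 'mean_loss_epoch' when trainer.batch_size is set, presumably because Lightning appends the `_epoch` suffix to per-epoch aggregates, and to 'mean_loss' otherwise.

    pbar = PINAProgressBar(metrics=["mean"])
    # after pbar.setup(trainer, solver, "fit") with trainer.batch_size = 32:
    #     pbar._sorted_metrics == ["mean_loss_epoch"]
    # with trainer.batch_size = None:
    #     pbar._sorted_metrics == ["mean_loss"]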