"""PINA Callbacks Implementations"""

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
import torch
import copy


class MetricTracker(Callback):

    def __init__(self):
        """
        PINA Implementation of a Lightning Callback for Metric Tracking.

        This class provides functionality to track relevant metrics during
        the training process.

        :ivar _collection: A list holding one deep copy of
            ``trainer.logged_metrics`` per tracked snapshot.

        Example:
            >>> tracker = MetricTracker()
            >>> # ... Perform training ...
            >>> metrics = tracker.metrics
        """
        self._collection = []

    def _track(self, trainer):
        """
        Store a snapshot of the trainer's logged metrics.

        Nothing is stored while ``trainer.current_epoch`` is zero.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :return: None
        :rtype: None
        """
        if trainer.current_epoch > 0:
            # Deep copy so later in-place updates of logged_metrics by the
            # trainer cannot alter snapshots already collected.
            self._collection.append(copy.deepcopy(trainer.logged_metrics))

    def on_train_epoch_start(self, trainer, pl_module):
        """
        Collect and track metrics at the start of each training epoch.

        At epoch zero the metric is not saved. At epoch ``k`` the metric
        which is tracked is the one of epoch ``k-1``.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param pl_module: Placeholder argument.
        :return: None
        :rtype: None
        """
        # Delegate to the parent hook with the same name as this one.
        super().on_train_epoch_start(trainer, pl_module)
        self._track(trainer)

    def on_train_end(self, trainer, pl_module):
        """
        Collect and track metrics at the end of training.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param pl_module: Placeholder argument.
        :return: None
        :rtype: None
        """
        super().on_train_end(trainer, pl_module)
        self._track(trainer)

    @property
    def metrics(self):
        """
        Aggregate collected metrics during training.

        Only metric names present in every collected snapshot are kept. The
        values for each name are stacked with ``torch.stack`` in collection
        order. This presumably requires each value to be a tensor of the
        same shape across snapshots; TODO confirm for all logged metrics.

        :return: A dictionary mapping each common metric name to the stacked
            tensor of its values, or an empty dictionary if nothing has been
            collected yet.
        :rtype: dict
        """
        if not self._collection:
            # set.intersection() with no arguments raises TypeError, so
            # handle the "nothing tracked yet" case explicitly.
            return {}
        first, *rest = self._collection
        common_keys = set(first).intersection(*rest)
        return {
            key: torch.stack([snapshot[key] for snapshot in self._collection])
            for key in common_keys
        }