83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
"""PINA Callbacks Implementations"""
|
|
|
|
from pytorch_lightning.callbacks import Callback
|
|
from pytorch_lightning.core.module import LightningModule
|
|
from pytorch_lightning.trainer.trainer import Trainer
|
|
import torch
|
|
import copy
|
|
|
|
|
|
class MetricTracker(Callback):

    def __init__(self):
        """
        PINA Implementation of a Lightning Callback for Metric Tracking.

        This class provides functionality to track relevant metrics during
        the training process. A snapshot of ``trainer.logged_metrics`` is
        stored at the start of every training epoch after the first one and
        once more when training ends; the snapshots are combined by the
        :attr:`metrics` property.

        :ivar _collection: A list of snapshots of ``trainer.logged_metrics``,
            one per tracked point in training.

        Example:
            >>> tracker = MetricTracker()
            >>> # ... Perform training ...
            >>> metrics = tracker.metrics
        """
        super().__init__()
        self._collection = []

    def _track(self, trainer):
        """
        Append a snapshot of the trainer's currently logged metrics.

        Nothing is stored while ``trainer.current_epoch`` is zero.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        """
        if trainer.current_epoch > 0:
            # deepcopy so the stored snapshot is independent of the trainer's
            # live metrics dictionary.
            self._collection.append(copy.deepcopy(trainer.logged_metrics))

    def on_train_epoch_start(self, trainer, pl_module):
        """
        Collect and track metrics at the start of each training epoch. At
        epoch zero the metric is not saved. At epoch ``k`` the metric which
        is tracked is the one of epoch ``k-1``.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param pl_module: The module being trained; not used by this
            callback.
        :type pl_module: pytorch_lightning.LightningModule
        :return: None
        :rtype: None
        """
        # Forward to the matching parent hook (previously this mistakenly
        # invoked ``on_train_epoch_end``).
        super().on_train_epoch_start(trainer, pl_module)
        self._track(trainer)

    def on_train_end(self, trainer, pl_module):
        """
        Collect and track metrics at the end of training.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param pl_module: The module being trained; not used by this
            callback.
        :type pl_module: pytorch_lightning.LightningModule
        :return: None
        :rtype: None
        """
        super().on_train_end(trainer, pl_module)
        self._track(trainer)

    @property
    def metrics(self):
        """
        Aggregate collected metrics during training.

        Only metric names present in every collected snapshot are kept. For
        each of them, the per-snapshot values are combined with
        :func:`torch.stack`, which adds a new leading dimension indexing the
        snapshots. Assumes every logged value is a tensor and that values of
        the same metric share a shape (required by ``torch.stack``).

        :return: A dictionary mapping each common metric name to the stacked
            tensor of its collected values; empty if nothing was collected.
        :rtype: dict
        """
        # set.intersection() with no arguments raises TypeError, so handle
        # the "nothing collected yet" case explicitly.
        if not self._collection:
            return {}
        common_keys = set.intersection(*map(set, self._collection))
        return {
            key: torch.stack([snapshot[key] for snapshot in self._collection])
            for key in common_keys
        }
|