PINA/pina/callbacks/processing_callbacks.py
dario-coscia 0fa4e1e58a
* Adding a test for all PINN solvers to assert that the metrics are correctly logged
* Adding a test for Metric Tracker
* Modify Metric Tracker to correctly log metrics
2024-08-12 14:48:09 +02:00

"""PINA Callbacks Implementations"""
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
import torch
import copy
class MetricTracker(Callback):
def __init__(self):
"""
PINA Implementation of a Lightning Callback for Metric Tracking.
This class provides functionality to track relevant metrics during the training process.
:ivar _collection: A list to store collected metrics after each training epoch.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
:return: A dictionary containing aggregated metric values.
:rtype: dict
Example:
>>> tracker = MetricTracker()
>>> # ... Perform training ...
>>> metrics = tracker.metrics
"""
self._collection = []

    def on_train_epoch_start(self, trainer, pl_module):
        """
        Collect and track metrics at the start of each training epoch. At
        epoch zero no metric is saved; at epoch ``k`` the metrics tracked
        are the ones logged during epoch ``k-1``.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param pl_module: Placeholder argument.
        :return: None
        :rtype: None
        """
        super().on_train_epoch_start(trainer, pl_module)
        if trainer.current_epoch > 0:
            self._collection.append(
                copy.deepcopy(trainer.logged_metrics)
            )  # track the metrics of the previous epoch

    def on_train_end(self, trainer, pl_module):
        """
        Collect and track metrics at the end of training.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param pl_module: Placeholder argument.
        :return: None
        :rtype: None
        """
        super().on_train_end(trainer, pl_module)
        if trainer.current_epoch > 0:
            self._collection.append(
                copy.deepcopy(trainer.logged_metrics)
            )  # track them

    @property
    def metrics(self):
        """
        Aggregate the metrics collected during training.

        :return: A dictionary mapping each tracked metric name to a tensor
            stacking its values over the tracked epochs.
        :rtype: dict
        """
        # keep only the metrics that were logged at every tracked epoch
        common_keys = set.intersection(*map(set, self._collection))
        return {
            k: torch.stack([dic[k] for dic in self._collection])
            for k in common_keys
        }
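

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It assumes the
# usual PINA workflow: ``solver`` is a previously built PINA solver (e.g. a
# PINN) and ``pina.trainer.Trainer`` forwards extra keyword arguments such as
# ``callbacks`` to the underlying ``pytorch_lightning.Trainer``; exact names
# may differ across PINA versions.
#
#     from pina.trainer import Trainer
#     from pina.callbacks.processing_callbacks import MetricTracker
#
#     tracker = MetricTracker()
#     trainer = Trainer(solver=solver, max_epochs=100, callbacks=[tracker])
#     trainer.train()
#
#     # After training, each tracked metric name (e.g. a loss term) maps to
#     # a tensor with one entry per tracked epoch.
#     for name, values in tracker.metrics.items():
#         print(name, values.shape)
# ---------------------------------------------------------------------------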