PINA/pina/callbacks/processing_callbacks.py
Dario Coscia 3f9305d475 Solvers logging (#202)
* Modifying solvers to log every epoch correctly
* add `on_epoch` flag to logger
* fix bug in `pinn.py` `pts -> samples` in `_loss_phys`
* add `optimizer_zero_grad()` in garom generator training loop
* modify imports in `callbacks.py`
* fixing tests

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
2023-11-17 09:51:29 +01:00

```python
'''PINA Callbacks Implementations'''
from pytorch_lightning.callbacks import Callback
import torch
import copy


class MetricTracker(Callback):
    """
    PINA implementation of a Lightning Callback to track relevant
    metrics during training.
    """

    def __init__(self):
        # One snapshot of the logged metrics per training epoch.
        self._collection = []

    def on_train_epoch_end(self, trainer, __):
        # Deep-copy so each stored snapshot is independent of the trainer's dict.
        self._collection.append(copy.deepcopy(trainer.logged_metrics))  # track them

    @property
    def metrics(self):
        # Keep only the metrics logged in every tracked epoch, then stack each
        # one along a new leading dimension (one entry per epoch).
        common_keys = set.intersection(*map(set, self._collection))
        v = {k: torch.stack([dic[k] for dic in self._collection]) for k in common_keys}
        return v
```
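Because `MetricTracker` needs `pytorch_lightning` and a real `Trainer`, its aggregation logic is easier to check in isolation. The sketch below is a dependency-free mirror of the class, not PINA code: `FakeTrainer` and `ListMetricTracker` are hypothetical names, plain lists stand in for `torch.stack`, and the fake trainer deliberately reuses one dict across epochs (no claim is made that Lightning does the same) to show why each snapshot is deep-copied. It also shows that a metric missing from any single epoch is dropped from the result.

```python
import copy


class FakeTrainer:
    """Hypothetical stand-in for a Lightning Trainer: it only exposes a
    `logged_metrics` dict, which this fake updates in place every epoch."""

    def __init__(self):
        self.logged_metrics = {}


class ListMetricTracker:
    """Dependency-free mirror of MetricTracker; lists replace torch.stack."""

    def __init__(self):
        self._collection = []

    def on_train_epoch_end(self, trainer, _pl_module):
        # Snapshot the dict: FakeTrainer mutates one dict in place, so storing
        # the reference would make every entry alias the same object.
        self._collection.append(copy.deepcopy(trainer.logged_metrics))

    @property
    def metrics(self):
        # Keep only keys present in every epoch, so each series has exactly
        # one value per epoch.
        common_keys = set.intersection(*map(set, self._collection))
        return {k: [d[k] for d in self._collection] for k in common_keys}


trainer = FakeTrainer()
tracker = ListMetricTracker()
epochs = [
    {"train_loss": 0.9, "val_loss": 1.1},
    {"train_loss": 0.5},  # val_loss not logged this epoch
    {"train_loss": 0.2, "val_loss": 0.4},
]
for logged in epochs:
    trainer.logged_metrics.clear()
    trainer.logged_metrics.update(logged)
    tracker.on_train_epoch_end(trainer, None)

print(tracker.metrics)  # {'train_loss': [0.9, 0.5, 0.2]}
```

Since `metrics` passes the stored snapshots to `set.intersection`, reading it before any epoch has finished raises a `TypeError`, both in this sketch and in the original class, because `set.intersection` called this way needs at least one argument.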