* Modifying solvers to log every epoch correctly
* add `on_epoch` flag to logger (see the sketch below)
* fix bug in `pinn.py`: `pts -> samples` in `_loss_phys`
* add `optimizer_zero_grad()` in garom generator training loop
* modify imports in `callbacks.py`
* fixing tests

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
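The `on_epoch` flag and the explicit gradient zeroing mentioned above are standard PyTorch Lightning patterns rather than anything PINA-specific. Below is a minimal sketch of how they typically appear in a manually optimized training step; `ToyGenerator`, its network, and its loss are placeholders for illustration, not PINA's actual solver code:

```python
import torch
import pytorch_lightning as pl


class ToyGenerator(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization, as in GAN-style loops
        self.net = torch.nn.Linear(1, 1)
        self.loss_fn = torch.nn.MSELoss()

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()  # zero gradients before the backward pass
        x, y = batch
        loss = self.loss_fn(self.net(x), y)
        self.manual_backward(loss)
        opt.step()
        # on_epoch=True makes Lightning aggregate and log this metric once per epoch
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.net.parameters(), lr=1e-3)
```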
The updated `callbacks.py`:
```python
'''PINA Callbacks Implementations'''

from pytorch_lightning.callbacks import Callback
import torch
import copy


class MetricTracker(Callback):
    """
    PINA implementation of a Lightning Callback to track relevant
    metrics during training.
    """

    def __init__(self):
        # one entry per epoch, each holding the metrics logged during that epoch
        self._collection = []

    def on_train_epoch_end(self, trainer, __):
        # snapshot the metrics logged in this epoch (deep copy so later
        # epochs do not overwrite earlier values)
        self._collection.append(copy.deepcopy(trainer.logged_metrics))

    @property
    def metrics(self):
        # keep only the metrics that were logged in every epoch, and stack
        # their per-epoch values into a single tensor per metric
        common_keys = set.intersection(*map(set, self._collection))
        v = {
            k: torch.stack([dic[k] for dic in self._collection])
            for k in common_keys
        }
        return v
```
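A minimal usage sketch of `MetricTracker`, assuming a plain PyTorch Lightning `Trainer` and some `LightningModule` named `model` that logs metrics via `self.log` (PINA's own Trainer wrapper may differ):

```python
import pytorch_lightning as pl

tracker = MetricTracker()
trainer = pl.Trainer(max_epochs=10, callbacks=[tracker])
trainer.fit(model)  # `model` is a placeholder LightningModule defined elsewhere

# `tracker.metrics` maps every metric that was logged in all epochs to a
# tensor of its per-epoch values, e.g. {'train_loss': tensor([...])}
history = tracker.metrics
```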