Tutorial update and small fixes * Tutorials update + Tutorial FNO * Create a metric tracker callback * Update PINN for logging * Update plotter for plotting * Small fix LabelTensor * Small fix FNO --------- Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-13-250.WIFIeduroamSTUD.units.it> Co-authored-by: Dario Coscia <dariocoscia@dhcp-176.eduroam.sissa.it>
25 lines
659 B
Python
25 lines
659 B
Python
'''PINA Callbacks Implementations'''
|
|
|
|
from lightning.pytorch.callbacks import Callback
|
|
import torch
|
|
import copy
|
|
|
|
|
|
class MetricTracker(Callback):
    """
    PINA implementation of a Lightning Callback to track relevant
    metrics during training.

    At the end of every training epoch a copy of ``trainer.logged_metrics``
    is stored. The :attr:`metrics` property stacks those per-epoch snapshots
    into a single tensor per metric name.
    """

    def __init__(self):
        """Create a tracker with an empty history of metric snapshots."""
        super().__init__()
        # One snapshot (dict of metric name -> value) per completed
        # training epoch, in the order the epochs ended.
        self._collection = []

    def on_train_epoch_end(self, trainer, __):
        """
        Record a snapshot of the metrics logged so far.

        :param trainer: The Lightning trainer; only its ``logged_metrics``
            attribute is read.
        :param __: The module being trained (unused).
        """
        # Deep-copy so the stored snapshot shares neither the dict nor the
        # tensors with the trainer, which keeps updating its own metrics.
        self._collection.append(copy.deepcopy(trainer.logged_metrics))

    @property
    def metrics(self):
        """
        Metrics tracked over all recorded training epochs.

        Only metric names that appear in *every* recorded snapshot are
        returned; keys follow the order of the first snapshot. Values are
        built with ``torch.stack``, so each metric is expected to be a
        tensor of the same shape in every epoch (assumed to hold for
        Lightning's logged scalars — confirm for custom metrics).

        :return: Mapping from metric name to the stacked per-epoch values;
            an empty dict if no epoch has been recorded yet.
        :rtype: dict
        """
        # Guard: with no snapshots, set.intersection() would be called with
        # no arguments and raise TypeError.
        if not self._collection:
            return {}

        first, *rest = self._collection
        # Preserve the first snapshot's key order instead of iterating a set,
        # so the result's ordering is deterministic across runs.
        common_keys = [
            key for key in first
            if all(key in snapshot for snapshot in rest)
        ]
        return {
            key: torch.stack([snapshot[key] for snapshot in self._collection])
            for key in common_keys
        }
|
|
|
|
|