Tutorials v0.1 (#178)
Tutorial update and small fixes * Tutorials update + Tutorial FNO * Create a metric tracker callback * Update PINN for logging * Update plotter for plotting * Small fix LabelTensor * Small fix FNO --------- Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-13-250.WIFIeduroamSTUD.units.it> Co-authored-by: Dario Coscia <dariocoscia@dhcp-176.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
939353f517
commit
a9b1bd2826
@@ -1,7 +1,9 @@
|
||||
__all__ = [
|
||||
'SwitchOptimizer',
|
||||
'R3Refinement',
|
||||
'MetricTracker'
|
||||
]
|
||||
|
||||
from .optimizer_callbacks import SwitchOptimizer
|
||||
from .adaptive_refinment_callbacks import R3Refinement
|
||||
from .adaptive_refinment_callbacks import R3Refinement
|
||||
from .processing_callbacks import MetricTracker
|
||||
25
pina/callbacks/processing_callbacks.py
Normal file
25
pina/callbacks/processing_callbacks.py
Normal file
@@ -0,0 +1,25 @@
|
||||
'''PINA Callbacks Implementations'''
|
||||
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
import torch
|
||||
import copy
|
||||
|
||||
|
||||
class MetricTracker(Callback):
|
||||
"""
|
||||
PINA implementation of a Lightining Callback to track relevant
|
||||
metrics during training.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._collection = []
|
||||
|
||||
def on_train_epoch_end(self, trainer, __):
|
||||
self._collection.append(copy.deepcopy(trainer.logged_metrics)) # track them
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
common_keys = set.intersection(*map(set, self._collection))
|
||||
v = {k: torch.stack([dic[k] for dic in self._collection]) for k in common_keys}
|
||||
return v
|
||||
|
||||
|
||||
Reference in New Issue
Block a user