* Adding a test for all PINN solvers to assert that the metrics are correctly log
* Adding test for Metric Tracker * Modify Metric Tracker to correctly log metrics
This commit is contained in:
committed by
Nicola Demo
parent
d00fb95d6e
commit
0fa4e1e58a
@@ -1,6 +1,8 @@
|
||||
"""PINA Callbacks Implementations"""
|
||||
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.core.module import LightningModule
|
||||
from pytorch_lightning.trainer.trainer import Trainer
|
||||
import torch
|
||||
import copy
|
||||
|
||||
@@ -28,20 +30,41 @@ class MetricTracker(Callback):
|
||||
"""
|
||||
self._collection = []
|
||||
|
||||
def on_train_epoch_end(self, trainer, __):
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
"""
|
||||
Collect and track metrics at the end of each training epoch.
|
||||
Collect and track metrics at the start of each training epoch. At epoch
|
||||
zero the metric is not saved. At epoch ``k`` the metric which is tracked
|
||||
is the one of epoch ``k-1``.
|
||||
|
||||
:param trainer: The trainer object managing the training process.
|
||||
:type trainer: pytorch_lightning.Trainer
|
||||
:param _: Placeholder argument.
|
||||
:param pl_module: Placeholder argument.
|
||||
|
||||
:return: None
|
||||
:rtype: None
|
||||
"""
|
||||
self._collection.append(
|
||||
copy.deepcopy(trainer.logged_metrics)
|
||||
) # track them
|
||||
super().on_train_epoch_end(trainer, pl_module)
|
||||
if trainer.current_epoch > 0:
|
||||
self._collection.append(
|
||||
copy.deepcopy(trainer.logged_metrics)
|
||||
) # track them
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
"""
|
||||
Collect and track metrics at the end of training.
|
||||
|
||||
:param trainer: The trainer object managing the training process.
|
||||
:type trainer: pytorch_lightning.Trainer
|
||||
:param pl_module: Placeholder argument.
|
||||
|
||||
:return: None
|
||||
:rtype: None
|
||||
"""
|
||||
super().on_train_end(trainer, pl_module)
|
||||
if trainer.current_epoch > 0:
|
||||
self._collection.append(
|
||||
copy.deepcopy(trainer.logged_metrics)
|
||||
) # track them
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
|
||||
Reference in New Issue
Block a user