* Adding a test for all PINN solvers to assert that the metrics are correctly log
* Adding test for Metric Tracker * Modify Metric Tracker to correctly log metrics
This commit is contained in:
committed by
Nicola Demo
parent
d00fb95d6e
commit
0fa4e1e58a
@@ -1,6 +1,8 @@
|
||||
"""PINA Callbacks Implementations"""
|
||||
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.core.module import LightningModule
|
||||
from pytorch_lightning.trainer.trainer import Trainer
|
||||
import torch
|
||||
import copy
|
||||
|
||||
@@ -28,20 +30,41 @@ class MetricTracker(Callback):
|
||||
"""
|
||||
self._collection = []
|
||||
|
||||
def on_train_epoch_end(self, trainer, __):
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
"""
|
||||
Collect and track metrics at the end of each training epoch.
|
||||
Collect and track metrics at the start of each training epoch. At epoch
|
||||
zero the metric is not saved. At epoch ``k`` the metric which is tracked
|
||||
is the one of epoch ``k-1``.
|
||||
|
||||
:param trainer: The trainer object managing the training process.
|
||||
:type trainer: pytorch_lightning.Trainer
|
||||
:param _: Placeholder argument.
|
||||
:param pl_module: Placeholder argument.
|
||||
|
||||
:return: None
|
||||
:rtype: None
|
||||
"""
|
||||
self._collection.append(
|
||||
copy.deepcopy(trainer.logged_metrics)
|
||||
) # track them
|
||||
super().on_train_epoch_end(trainer, pl_module)
|
||||
if trainer.current_epoch > 0:
|
||||
self._collection.append(
|
||||
copy.deepcopy(trainer.logged_metrics)
|
||||
) # track them
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
"""
|
||||
Collect and track metrics at the end of training.
|
||||
|
||||
:param trainer: The trainer object managing the training process.
|
||||
:type trainer: pytorch_lightning.Trainer
|
||||
:param pl_module: Placeholder argument.
|
||||
|
||||
:return: None
|
||||
:rtype: None
|
||||
"""
|
||||
super().on_train_end(trainer, pl_module)
|
||||
if trainer.current_epoch > 0:
|
||||
self._collection.append(
|
||||
copy.deepcopy(trainer.logged_metrics)
|
||||
) # track them
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
|
||||
@@ -195,15 +195,20 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
)
|
||||
|
||||
# check consistency location
|
||||
locations_to_sample = [
|
||||
condition for condition in self.conditions
|
||||
if hasattr(self.conditions[condition], 'location')
|
||||
]
|
||||
if locations == "all":
|
||||
locations = [condition for condition in self.conditions]
|
||||
# only locations that can be sampled
|
||||
locations = locations_to_sample
|
||||
else:
|
||||
check_consistency(locations, str)
|
||||
|
||||
if sorted(locations) != sorted(self.conditions):
|
||||
if sorted(locations) != sorted(locations_to_sample):
|
||||
TypeError(
|
||||
f"Wrong locations for sampling. Location ",
|
||||
f"should be in {self.conditions}.",
|
||||
f"should be in {locations_to_sample}.",
|
||||
)
|
||||
|
||||
# sampling
|
||||
|
||||
Reference in New Issue
Block a user