* Adding a test for all PINN solvers to assert that the metrics are correctly logged

* Adding a test for the Metric Tracker
* Modify Metric Tracker to correctly log metrics
Author:       dario-coscia
Date:         2024-08-06 11:12:19 +02:00
Committed by: Nicola Demo
Parent:       d00fb95d6e
Commit:       0fa4e1e58a

10 changed files with 308 additions and 9 deletions


@@ -1,6 +1,8 @@
 """PINA Callbacks Implementations"""
 
 from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.core.module import LightningModule
+from pytorch_lightning.trainer.trainer import Trainer
 import torch
 import copy
 
@@ -28,20 +30,41 @@ class MetricTracker(Callback):
         """
         self._collection = []
 
-    def on_train_epoch_end(self, trainer, __):
+    def on_train_epoch_start(self, trainer, pl_module):
         """
-        Collect and track metrics at the end of each training epoch.
+        Collect and track metrics at the start of each training epoch. At epoch
+        zero the metric is not saved. At epoch ``k`` the metric which is tracked
+        is the one of epoch ``k-1``.
 
         :param trainer: The trainer object managing the training process.
         :type trainer: pytorch_lightning.Trainer
-        :param _: Placeholder argument.
+        :param pl_module: Placeholder argument.
 
         :return: None
         :rtype: None
         """
-        self._collection.append(
-            copy.deepcopy(trainer.logged_metrics)
-        )  # track them
+        super().on_train_epoch_end(trainer, pl_module)
+        if trainer.current_epoch > 0:
+            self._collection.append(
+                copy.deepcopy(trainer.logged_metrics)
+            )  # track them
+
+    def on_train_end(self, trainer, pl_module):
+        """
+        Collect and track metrics at the end of training.
+
+        :param trainer: The trainer object managing the training process.
+        :type trainer: pytorch_lightning.Trainer
+        :param pl_module: Placeholder argument.
+
+        :return: None
+        :rtype: None
+        """
+        super().on_train_end(trainer, pl_module)
+        if trainer.current_epoch > 0:
+            self._collection.append(
+                copy.deepcopy(trainer.logged_metrics)
+            )  # track them
 
     @property
     def metrics(self):
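To make the tracking behaviour above concrete, here is a minimal standalone sketch (plain Python, not PINA code) that simulates the two hooks: at the start of epoch k (with k > 0) the metrics logged during epoch k-1 are stored, and on_train_end picks up the final epoch, so a run of N epochs ends with exactly N entries. The FakeTrainer class and the "mean_loss" key are illustrative stand-ins, not names from the repository.

import copy

class FakeTrainer:
    """Illustrative stand-in exposing only the attributes the hooks read."""
    def __init__(self):
        self.current_epoch = 0
        self.logged_metrics = {}

collection = []  # mirrors MetricTracker._collection

def on_train_epoch_start(trainer):
    # at epoch 0 there is nothing to save yet; at epoch k we save epoch k-1
    if trainer.current_epoch > 0:
        collection.append(copy.deepcopy(trainer.logged_metrics))

def on_train_end(trainer):
    # capture the last epoch's metrics, which no epoch start ever sees
    if trainer.current_epoch > 0:
        collection.append(copy.deepcopy(trainer.logged_metrics))

trainer = FakeTrainer()
for epoch in range(3):
    trainer.current_epoch = epoch
    on_train_epoch_start(trainer)
    trainer.logged_metrics = {"mean_loss": float(epoch)}  # "training" logs here

on_train_end(trainer)
assert [m["mean_loss"] for m in collection] == [0.0, 1.0, 2.0]  # one entry per epoch

In an actual run the same history would be read back through the callback's metrics property after training, assuming a MetricTracker instance was passed to the trainer's callbacks.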


@@ -195,15 +195,20 @@ class AbstractProblem(metaclass=ABCMeta):
             )
 
         # check consistency location
+        locations_to_sample = [
+            condition for condition in self.conditions
+            if hasattr(self.conditions[condition], 'location')
+        ]
         if locations == "all":
-            locations = [condition for condition in self.conditions]
+            # only locations that can be sampled
+            locations = locations_to_sample
         else:
             check_consistency(locations, str)
-            if sorted(locations) != sorted(self.conditions):
+            if sorted(locations) != sorted(locations_to_sample):
                 TypeError(
                     f"Wrong locations for sampling. Location ",
-                    f"should be in {self.conditions}.",
+                    f"should be in {locations_to_sample}.",
                 )
 
         # sampling
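To illustrate the change to AbstractProblem, the following is a minimal standalone sketch (not PINA code) of the new filtering: only conditions that expose a location attribute are treated as samplable, so with locations="all" a purely data-driven condition is skipped instead of being handed to the sampler. The condition classes and names below are illustrative assumptions.

class LocationCondition:
    """Illustrative stand-in for a condition defined over a geometric location."""
    def __init__(self, location):
        self.location = location

class DataCondition:
    """Illustrative stand-in for a data condition (no location to sample from)."""
    def __init__(self, input_points, output_points):
        self.input_points = input_points
        self.output_points = output_points

conditions = {
    "gamma_boundary": LocationCondition(location="unit_square_border"),
    "measurements": DataCondition(input_points=[0.1, 0.7], output_points=[1.2, 0.3]),
}

# mirrors the locations_to_sample computation introduced in the hunk above
locations_to_sample = [
    name for name in conditions if hasattr(conditions[name], "location")
]
assert locations_to_sample == ["gamma_boundary"]

# with locations="all", sampling now targets only these conditions
locations = locations_to_sample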