fix tests

This commit is contained in:
Nicola Demo
2025-01-23 09:52:23 +01:00
parent 9aed1a30b3
commit a899327de1
32 changed files with 2331 additions and 2428 deletions

View File

@@ -14,7 +14,7 @@ from pina.utils import check_consistency
class MetricTracker(Callback):
def __init__(self):
def __init__(self, metrics_to_track=None):
"""
PINA Implementation of a Lightning Callback for Metric Tracking.
@@ -37,6 +37,9 @@ class MetricTracker(Callback):
"""
super().__init__()
self._collection = []
if metrics_to_track is not None:
metrics_to_track = ['train_loss_epoch', 'train_loss_step', 'val_loss']
self.metrics_to_track = metrics_to_track
def on_train_epoch_end(self, trainer, pl_module):
"""
@@ -72,7 +75,7 @@ class PINAProgressBar(TQDMProgressBar):
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
def __init__(self, metrics="mean", **kwargs):
def __init__(self, metrics="val_loss", **kwargs):
"""
PINA Implementation of a Lightning Callback for enriching the progress
bar.