fix tests
This commit is contained in:
@@ -14,7 +14,7 @@ from pina.utils import check_consistency
|
||||
|
||||
class MetricTracker(Callback):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, metrics_to_track=None):
|
||||
"""
|
||||
PINA Implementation of a Lightning Callback for Metric Tracking.
|
||||
|
||||
@@ -37,6 +37,9 @@ class MetricTracker(Callback):
|
||||
"""
|
||||
super().__init__()
|
||||
self._collection = []
|
||||
if metrics_to_track is not None:
|
||||
metrics_to_track = ['train_loss_epoch', 'train_loss_step', 'val_loss']
|
||||
self.metrics_to_track = metrics_to_track
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
"""
|
||||
@@ -72,7 +75,7 @@ class PINAProgressBar(TQDMProgressBar):
|
||||
|
||||
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
|
||||
|
||||
def __init__(self, metrics="mean", **kwargs):
|
||||
def __init__(self, metrics="val_loss", **kwargs):
|
||||
"""
|
||||
PINA Implementation of a Lightning Callback for enriching the progress
|
||||
bar.
|
||||
|
||||
Reference in New Issue
Block a user