Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -26,7 +26,7 @@ class MetricTracker(Callback):
super().__init__()
self._collection = []
# Default to tracking 'train_loss' and 'val_loss' if not specified
self.metrics_to_track = metrics_to_track or ['train_loss', 'val_loss']
self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"]
def on_train_epoch_end(self, trainer, pl_module):
"""
@@ -40,7 +40,8 @@ class MetricTracker(Callback):
if trainer.current_epoch > 0:
# Append only the tracked metrics to avoid unnecessary data
tracked_metrics = {
k: v for k, v in trainer.logged_metrics.items()
k: v
for k, v in trainer.logged_metrics.items()
if k in self.metrics_to_track
}
self._collection.append(copy.deepcopy(tracked_metrics))
@@ -57,16 +58,18 @@ class MetricTracker(Callback):
return {}
# Get intersection of keys across all collected dictionaries
common_keys = set(self._collection[0]).intersection(*self._collection[1:])
common_keys = set(self._collection[0]).intersection(
*self._collection[1:]
)
# Stack the metric values for common keys and return
return {
k: torch.stack([dic[k] for dic in self._collection])
for k in common_keys if k in self.metrics_to_track
for k in common_keys
if k in self.metrics_to_track
}
class PINAProgressBar(TQDMProgressBar):
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
@@ -142,7 +145,8 @@ class PINAProgressBar(TQDMProgressBar):
for key in self._sorted_metrics:
if (
key not in trainer.solver.problem.conditions.keys()
and key != "train" and key != "val"
and key != "train"
and key != "val"
):
raise KeyError(f"Key '{key}' is not present in the dictionary")
# add the loss pedix