Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -67,7 +67,7 @@ class R3Refinement(Callback):
|
||||
# compute residual
|
||||
res_loss = {}
|
||||
tot_loss = []
|
||||
for location in self._sampling_locations: #TODO fix for new collector
|
||||
for location in self._sampling_locations: # TODO fix for new collector
|
||||
condition = solver.problem.conditions[location]
|
||||
pts = solver.problem.input_pts[location]
|
||||
# send points to correct device
|
||||
|
||||
@@ -26,7 +26,7 @@ class MetricTracker(Callback):
|
||||
super().__init__()
|
||||
self._collection = []
|
||||
# Default to tracking 'train_loss' and 'val_loss' if not specified
|
||||
self.metrics_to_track = metrics_to_track or ['train_loss', 'val_loss']
|
||||
self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"]
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
"""
|
||||
@@ -40,7 +40,8 @@ class MetricTracker(Callback):
|
||||
if trainer.current_epoch > 0:
|
||||
# Append only the tracked metrics to avoid unnecessary data
|
||||
tracked_metrics = {
|
||||
k: v for k, v in trainer.logged_metrics.items()
|
||||
k: v
|
||||
for k, v in trainer.logged_metrics.items()
|
||||
if k in self.metrics_to_track
|
||||
}
|
||||
self._collection.append(copy.deepcopy(tracked_metrics))
|
||||
@@ -57,16 +58,18 @@ class MetricTracker(Callback):
|
||||
return {}
|
||||
|
||||
# Get intersection of keys across all collected dictionaries
|
||||
common_keys = set(self._collection[0]).intersection(*self._collection[1:])
|
||||
|
||||
common_keys = set(self._collection[0]).intersection(
|
||||
*self._collection[1:]
|
||||
)
|
||||
|
||||
# Stack the metric values for common keys and return
|
||||
return {
|
||||
k: torch.stack([dic[k] for dic in self._collection])
|
||||
for k in common_keys if k in self.metrics_to_track
|
||||
for k in common_keys
|
||||
if k in self.metrics_to_track
|
||||
}
|
||||
|
||||
|
||||
|
||||
class PINAProgressBar(TQDMProgressBar):
|
||||
|
||||
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
|
||||
@@ -142,7 +145,8 @@ class PINAProgressBar(TQDMProgressBar):
|
||||
for key in self._sorted_metrics:
|
||||
if (
|
||||
key not in trainer.solver.problem.conditions.keys()
|
||||
and key != "train" and key != "val"
|
||||
and key != "train"
|
||||
and key != "val"
|
||||
):
|
||||
raise KeyError(f"Key '{key}' is not present in the dictionary")
|
||||
# add the loss pedix
|
||||
|
||||
Reference in New Issue
Block a user