Format Python code with psf/black push (#325)

* 🎨 Format Python code with psf/black
This commit is contained in:
github-actions[bot]
2024-08-12 18:30:46 +02:00
committed by GitHub
parent cce9876751
commit 5445559cb2
5 changed files with 85 additions and 56 deletions

View File

@@ -2,8 +2,8 @@ __all__ = [
"SwitchOptimizer",
"R3Refinement",
"MetricTracker",
"PINAProgressBar"
]
"PINAProgressBar",
]
from .optimizer_callbacks import SwitchOptimizer
from .adaptive_refinment_callbacks import R3Refinement

View File

@@ -6,9 +6,12 @@ import torch
import copy
from pytorch_lightning.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics
from lightning.pytorch.callbacks.progress.progress_bar import (
get_standard_metrics,
)
from pina.utils import check_consistency
class MetricTracker(Callback):
def __init__(self):
@@ -37,7 +40,7 @@ class MetricTracker(Callback):
def on_train_epoch_end(self, trainer, pl_module):
"""
Collect and track metrics at the end of each training epoch.
Collect and track metrics at the end of each training epoch.
:param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer
@@ -68,7 +71,8 @@ class MetricTracker(Callback):
class PINAProgressBar(TQDMProgressBar):
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
def __init__(self, metrics='mean', **kwargs):
def __init__(self, metrics="mean", **kwargs):
"""
PINA Implementation of a Lightning Callback for enriching the progress
bar.
@@ -123,7 +127,7 @@ class PINAProgressBar(TQDMProgressBar):
if pbar_metrics:
pbar_metrics = {
key: pbar_metrics[key] for key in self._sorted_metrics
}
}
duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
if duplicates:
rank_zero_warn(
@@ -133,7 +137,7 @@ class PINAProgressBar(TQDMProgressBar):
)
return {**standard_metrics, **pbar_metrics}
def on_fit_start(self, trainer, pl_module):
"""
Check that the metrics defined in the initialization are available,
@@ -145,9 +149,13 @@ class PINAProgressBar(TQDMProgressBar):
"""
# Check if all keys in sort_keys are present in the dictionary
for key in self._sorted_metrics:
if key not in trainer.solver.problem.conditions.keys() and key != 'mean':
if (
key not in trainer.solver.problem.conditions.keys()
and key != "mean"
):
raise KeyError(f"Key '{key}' is not present in the dictionary")
# add the loss pedix
self._sorted_metrics = [
metric + '_loss' for metric in self._sorted_metrics]
return super().on_fit_start(trainer, pl_module)
metric + "_loss" for metric in self._sorted_metrics
]
return super().on_fit_start(trainer, pl_module)