Format Python code with psf/black push (#325)
* 🎨 Format Python code with psf/black
commit 5445559cb2
parent cce9876751
committed by GitHub
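The change below is purely mechanical: black normalises string quotes to double quotes, adds trailing commas in exploded literals, and splits lines that exceed the configured length. As a minimal sketch (not part of this commit, and the exact command or configuration used by the workflow is not shown here), the same normalisation can be reproduced locally through black's documented library entry point, or by running the `black` CLI over the package:

# Minimal sketch (assumption: black is installed, e.g. via `pip install black`).
# black.format_str() is the library counterpart of running `black <path>`.
import black

src = "def __init__(self, metrics='mean', **kwargs): pass\n"
formatted = black.format_str(src, mode=black.Mode())
print(formatted)
# Expected output (roughly):
# def __init__(self, metrics="mean", **kwargs):
#     pass
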
@@ -2,8 +2,8 @@ __all__ = [
     "SwitchOptimizer",
     "R3Refinement",
     "MetricTracker",
-    "PINAProgressBar"
+    "PINAProgressBar",
 ]
 
 from .optimizer_callbacks import SwitchOptimizer
 from .adaptive_refinment_callbacks import R3Refinement
@@ -6,9 +6,12 @@ import torch
 import copy
 
 from pytorch_lightning.callbacks import Callback, TQDMProgressBar
-from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics
+from lightning.pytorch.callbacks.progress.progress_bar import (
+    get_standard_metrics,
+)
 from pina.utils import check_consistency
 
+
 class MetricTracker(Callback):
 
     def __init__(self):
@@ -37,7 +40,7 @@ class MetricTracker(Callback):
 
     def on_train_epoch_end(self, trainer, pl_module):
         """
         Collect and track metrics at the end of each training epoch.
 
         :param trainer: The trainer object managing the training process.
         :type trainer: pytorch_lightning.Trainer
@@ -68,7 +71,8 @@ class MetricTracker(Callback):
 class PINAProgressBar(TQDMProgressBar):
 
     BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
-    def __init__(self, metrics='mean', **kwargs):
+
+    def __init__(self, metrics="mean", **kwargs):
         """
         PINA Implementation of a Lightning Callback for enriching the progress
         bar.
@@ -123,7 +127,7 @@ class PINAProgressBar(TQDMProgressBar):
         if pbar_metrics:
             pbar_metrics = {
                 key: pbar_metrics[key] for key in self._sorted_metrics
-                }
+            }
         duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
         if duplicates:
             rank_zero_warn(
@@ -133,7 +137,7 @@ class PINAProgressBar(TQDMProgressBar):
             )
 
         return {**standard_metrics, **pbar_metrics}
 
 
     def on_fit_start(self, trainer, pl_module):
         """
         Check that the metrics defined in the initialization are available,
@@ -145,9 +149,13 @@ class PINAProgressBar(TQDMProgressBar):
         """
         # Check if all keys in sort_keys are present in the dictionary
         for key in self._sorted_metrics:
-            if key not in trainer.solver.problem.conditions.keys() and key != 'mean':
+            if (
+                key not in trainer.solver.problem.conditions.keys()
+                and key != "mean"
+            ):
                 raise KeyError(f"Key '{key}' is not present in the dictionary")
         # add the loss pedix
         self._sorted_metrics = [
-            metric + '_loss' for metric in self._sorted_metrics]
+            metric + "_loss" for metric in self._sorted_metrics
+        ]
         return super().on_fit_start(trainer, pl_module)
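The reformatting above does not alter behaviour. For orientation, a short usage sketch of the callback touched in this diff; the import path and trainer wiring are assumptions inferred from the edited package `__init__`, not something this commit shows:

# Usage sketch (assumptions: the callback is exported as
# pina.callbacks.PINAProgressBar and the trainer accepts Lightning callbacks,
# since PINAProgressBar subclasses TQDMProgressBar).
from pina.callbacks import PINAProgressBar

# 'metrics' selects which losses the bar displays; "mean" is the default in
# the signature above and presumably refers to the aggregated mean loss.
progress_bar = PINAProgressBar(metrics="mean")

# Hypothetical wiring: pass it like any Lightning progress-bar callback, e.g.
# trainer = Trainer(solver=solver, callbacks=[progress_bar])
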