Update solvers (#434)

* Enable DDP training with batch_size=None and add a validity check for split sizes (see the sketch after this list)
* Refactor SolverInterfaces (#435)
* Update solvers and add weighting
* Update PINN for 0.2
* Modify GAROM and its tests
* Add more versatile loggers
* Disable compilation when running on Windows
* Fix tests
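
For the DDP bullet, a minimal sketch of the new unbatched path; it assumes PINA's Trainer forwards Lightning's accelerator/devices/strategy keywords and that `solver` is an already-built PINA solver:

    from pina import Trainer

    trainer = Trainer(
        solver,
        batch_size=None,   # feed each dataset split as a single batch
        accelerator="gpu",
        devices=2,
        strategy="ddp",    # DistributedDataParallel across the two GPUs
        max_epochs=1000,
    )
    trainer.train()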

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
Author: Dario Coscia
Date: 2025-02-17 11:26:21 +01:00
Committed by: Nicola Demo
Parent: 780c4921eb
Commit: 9cae9a438f
50 changed files with 2848 additions and 4187 deletions


@@ -1,7 +1,5 @@
 """PINA Callbacks Implementations"""
-from lightning.pytorch.core.module import LightningModule
-from lightning.pytorch.trainer.trainer import Trainer
 import torch
 import copy
@@ -16,30 +14,19 @@ class MetricTracker(Callback):
     def __init__(self, metrics_to_track=None):
         """
-        PINA Implementation of a Lightning Callback for Metric Tracking.
+        Lightning Callback for Metric Tracking.
-        This class provides functionality to track relevant metrics during
-        the training process.
+        Tracks specific metrics during the training process.
-        :ivar _collection: A list to store collected metrics after each
-            training epoch.
+        :ivar _collection: A list to store collected metrics after each epoch.
-        :param trainer: The trainer object managing the training process.
-        :type trainer: pytorch_lightning.Trainer
-        :return: A dictionary containing aggregated metric values.
-        :rtype: dict
-        Example:
-            >>> tracker = MetricTracker()
-            >>> # ... Perform training ...
-            >>> metrics = tracker.metrics
+        :param metrics_to_track: List of metrics to track. Defaults to train/val loss.
+        :type metrics_to_track: list, optional
         """
         super().__init__()
         self._collection = []
-        if metrics_to_track is not None:
-            metrics_to_track = ['train_loss_epoch', 'train_loss_step', 'val_loss']
-        self.metrics_to_track = metrics_to_track
+        # Default to tracking 'train_loss' and 'val_loss' if not specified
+        self.metrics_to_track = metrics_to_track or ['train_loss', 'val_loss']

     def on_train_epoch_end(self, trainer, pl_module):
         """
@@ -47,35 +34,44 @@ class MetricTracker(Callback):
         :param trainer: The trainer object managing the training process.
         :type trainer: pytorch_lightning.Trainer
-        :param pl_module: Placeholder argument.
+        :param pl_module: The model being trained (not used here).
         """
         super().on_train_epoch_end(trainer, pl_module)
+        # Track metrics after the first epoch onwards
         if trainer.current_epoch > 0:
-            self._collection.append(
-                copy.deepcopy(trainer.logged_metrics)
-            )  # track them
+            # Append only the tracked metrics to avoid unnecessary data
+            tracked_metrics = {
+                k: v for k, v in trainer.logged_metrics.items()
+                if k in self.metrics_to_track
+            }
+            self._collection.append(copy.deepcopy(tracked_metrics))

     @property
     def metrics(self):
         """
-        Aggregate collected metrics during training.
+        Aggregate collected metrics over all epochs.

         :return: A dictionary containing aggregated metric values.
         :rtype: dict
         """
-        common_keys = set.intersection(*map(set, self._collection))
-        v = {
+        if not self._collection:
+            return {}
+        # Get intersection of keys across all collected dictionaries
+        common_keys = set(self._collection[0]).intersection(*self._collection[1:])
+        # Stack the metric values for common keys and return
+        return {
             k: torch.stack([dic[k] for dic in self._collection])
-            for k in common_keys
+            for k in common_keys if k in self.metrics_to_track
         }
-        return v


 class PINAProgressBar(TQDMProgressBar):

     BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"

-    def __init__(self, metrics="val_loss", **kwargs):
+    def __init__(self, metrics="val", **kwargs):
         """
         PINA Implementation of a Lightning Callback for enriching the progress
         bar.
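
The new aggregation in `metrics` can be exercised in isolation (toy data, not PINA output): keys are intersected across the per-epoch snapshots, so a metric missing from any epoch is dropped, and the survivors are stacked into one tensor per metric:

    import torch

    # Two fake per-epoch snapshots, mimicking trainer.logged_metrics.
    collection = [
        {"train_loss": torch.tensor(0.9), "val_loss": torch.tensor(1.1)},
        {"train_loss": torch.tensor(0.5), "val_loss": torch.tensor(0.7)},
    ]
    # Keys present in every snapshot survive the intersection.
    common_keys = set(collection[0]).intersection(*collection[1:])
    aggregated = {k: torch.stack([d[k] for d in collection]) for k in common_keys}
    # aggregated["train_loss"] -> tensor([0.9000, 0.5000])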
@@ -131,14 +127,6 @@ class PINAProgressBar(TQDMProgressBar):
         pbar_metrics = {
             key: pbar_metrics[key] for key in self._sorted_metrics
         }
-        duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
-        if duplicates:
-            rank_zero_warn(
-                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
-                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
-                " If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
-            )
-
         return {**standard_metrics, **pbar_metrics}

     def on_fit_start(self, trainer, pl_module):
@@ -154,7 +142,7 @@ class PINAProgressBar(TQDMProgressBar):
         for key in self._sorted_metrics:
             if (
                 key not in trainer.solver.problem.conditions.keys()
-                and key != "mean"
+                and key != "train" and key != "val"
             ):
                 raise KeyError(f"Key '{key}' is not present in the dictionary")
             # add the loss pedix
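
With the renamed keys, only the problem's condition names plus "train" and "val" pass the check above (replacing the old "mean"). A sketch of enriching the bar, assuming `metrics` also accepts a list (as the `_sorted_metrics` loop suggests) and a hypothetical condition named "data":

    # Sketch: show the validation loss and one condition's loss on the bar.
    pbar = PINAProgressBar(metrics=["val", "data"])
    trainer = Trainer(solver, callbacks=[pbar], max_epochs=50)
    trainer.train()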