fix tests

This commit is contained in:
Nicola Demo
2025-01-23 09:52:23 +01:00
parent 9aed1a30b3
commit a899327de1
32 changed files with 2331 additions and 2428 deletions

View File

@@ -49,7 +49,7 @@ class R3Refinement(Callback):
"""
# extract the solver and device from trainer
solver = trainer._model
solver = trainer.solver
device = trainer._accelerator_connector._accelerator_flag
precision = trainer.precision
if precision == "64-true":
@@ -67,7 +67,7 @@ class R3Refinement(Callback):
# compute residual
res_loss = {}
tot_loss = []
for location in self._sampling_locations:
for location in self._sampling_locations: #TODO fix for new collector
condition = solver.problem.conditions[location]
pts = solver.problem.input_pts[location]
# send points to correct device
@@ -79,6 +79,8 @@ class R3Refinement(Callback):
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target))
print(tot_loss)
return torch.vstack(tot_loss), res_loss
def _r3_routine(self, trainer):
@@ -139,7 +141,7 @@ class R3Refinement(Callback):
:rtype: None
"""
# extract locations for sampling
problem = trainer._model.problem
problem = trainer.solver.problem
locations = []
for condition_name in problem.conditions:
condition = problem.conditions[condition_name]

View File

@@ -3,61 +3,45 @@
from lightning.pytorch.callbacks import Callback
import torch
from ..utils import check_consistency
from pina.optim import TorchOptimizer
class SwitchOptimizer(Callback):
def __init__(self, new_optimizers, new_optimizers_kwargs, epoch_switch):
def __init__(self, new_optimizers, epoch_switch):
"""
PINA Implementation of a Lightning Callback to switch optimizer during training.
PINA Implementation of a Lightning Callback to switch optimizer during
training.
This callback allows for switching between different optimizers during training, enabling
the exploration of multiple optimization strategies without the need to stop training.
This callback allows for switching between different optimizers during
training, enabling the exploration of multiple optimization strategies
without the need to stop training.
:param new_optimizers: The model optimizers to switch to. Can be a single
:class:`torch.optim.Optimizer` or a list of them for multiple model solvers.
:type new_optimizers: torch.optim.Optimizer | list
:param new_optimizers_kwargs: The keyword arguments for the new optimizers. Can be a single dictionary
or a list of dictionaries corresponding to each optimizer.
:type new_optimizers_kwargs: dict | list
:param new_optimizers: The model optimizers to switch to. Can be a
single :class:`torch.optim.Optimizer` or a list of them for multiple
model solvers.
:type new_optimizers: pina.optim.TorchOptimizer | list
:param epoch_switch: The epoch at which to switch to the new optimizer.
:type epoch_switch: int
:raises ValueError: If `epoch_switch` is less than 1 or if there is a mismatch in the number of
optimizers and their corresponding keyword argument dictionaries.
Example:
>>> switch_callback = SwitchOptimizer(new_optimizers=[optimizer1, optimizer2],
>>> new_optimizers_kwargs=[{'lr': 0.001}, {'lr': 0.01}],
>>> switch_callback = SwitchOptimizer(new_optimizers=optimizer,
>>> epoch_switch=10)
"""
super().__init__()
# check type consistency
check_consistency(new_optimizers, torch.optim.Optimizer, subclass=True)
check_consistency(new_optimizers_kwargs, dict)
check_consistency(epoch_switch, int)
if epoch_switch < 1:
raise ValueError("epoch_switch must be greater than one.")
if not isinstance(new_optimizers, list):
new_optimizers = [new_optimizers]
new_optimizers_kwargs = [new_optimizers_kwargs]
len_optimizer = len(new_optimizers)
len_optimizer_kwargs = len(new_optimizers_kwargs)
if len_optimizer_kwargs != len_optimizer:
raise ValueError(
"You must define one dictionary of keyword"
" arguments for each optimizers."
f" Got {len_optimizer} optimizers, and"
f" {len_optimizer_kwargs} dicitionaries"
)
# check type consistency
for optimizer in new_optimizers:
check_consistency(optimizer, TorchOptimizer)
check_consistency(epoch_switch, int)
# save new optimizers
self._new_optimizers = new_optimizers
self._new_optimizers_kwargs = new_optimizers_kwargs
self._epoch_switch = epoch_switch
def on_train_epoch_start(self, trainer, __):
@@ -73,13 +57,9 @@ class SwitchOptimizer(Callback):
"""
if trainer.current_epoch == self._epoch_switch:
optims = []
for idx, (optim, optim_kwargs) in enumerate(
zip(self._new_optimizers, self._new_optimizers_kwargs)
):
optims.append(
optim(
trainer._model.models[idx].parameters(), **optim_kwargs
)
)
for idx, optim in enumerate(self._new_optimizers):
optim.hook(trainer.solver.models[idx].parameters())
optims.append(optim.instance)
trainer.optimizers = optims

View File

@@ -14,7 +14,7 @@ from pina.utils import check_consistency
class MetricTracker(Callback):
def __init__(self):
def __init__(self, metrics_to_track=None):
"""
PINA Implementation of a Lightning Callback for Metric Tracking.
@@ -37,6 +37,9 @@ class MetricTracker(Callback):
"""
super().__init__()
self._collection = []
if metrics_to_track is not None:
metrics_to_track = ['train_loss_epoch', 'train_loss_step', 'val_loss']
self.metrics_to_track = metrics_to_track
def on_train_epoch_end(self, trainer, pl_module):
"""
@@ -72,7 +75,7 @@ class PINAProgressBar(TQDMProgressBar):
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
def __init__(self, metrics="mean", **kwargs):
def __init__(self, metrics="val_loss", **kwargs):
"""
PINA Implementation of a Lightning Callback for enriching the progress
bar.