Update callbacks and tests (#482)

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>

Commit: 632934f9cc
Parent: 6ae301622b
Committed by: Nicola Demo
@@ -1,5 +1,6 @@
 """PINA Callbacks Implementations"""

+import importlib.metadata
 import torch
 from lightning.pytorch.callbacks import Callback
 from ..label_tensor import LabelTensor
@@ -7,17 +8,17 @@ from ..utils import check_consistency


 class R3Refinement(Callback):
+    """
+    PINA Implementation of an R3 Refinement Callback.
+    """

     def __init__(self, sample_every):
         """
-        PINA Implementation of an R3 Refinement Callback.
-
         This callback implements the R3 (Retain-Resample-Release) routine for
         sampling new points based on adaptive search.
         The algorithm incrementally accumulates collocation points in regions
-        of high PDE residuals, and releases those
-        with low residuals. Points are sampled uniformly in all regions
-        where sampling is needed.
+        of high PDE residuals, and releases those with low residuals.
+        Points are sampled uniformly in all regions where sampling is needed.

         .. seealso::
@@ -33,142 +34,148 @@ class R3Refinement(Callback):
         Example:
             >>> r3_callback = R3Refinement(sample_every=5)
         """
-        super().__init__()
-
-        # sample every
-        check_consistency(sample_every, int)
-        self._sample_every = sample_every
-        self._const_pts = None
-
-    def _compute_residual(self, trainer):
-        """
-        Computes the residuals for a PINN object.
-
-        :return: the total loss, and pointwise loss.
-        :rtype: tuple
-        """
-
-        # extract the solver and device from trainer
-        solver = trainer.solver
-        device = trainer._accelerator_connector._accelerator_flag
-        precision = trainer.precision
-        if precision == "64-true":
-            precision = torch.float64
-        elif precision == "32-true":
-            precision = torch.float32
-        else:
-            raise RuntimeError(
-                "Currently R3Refinement is only implemented "
-                "for precision '32-true' and '64-true', set "
-                "Trainer precision to match one of the "
-                "available precisions."
-            )
-
-        # compute residual
-        res_loss = {}
-        tot_loss = []
-        for location in self._sampling_locations:  # TODO fix for new collector
-            condition = solver.problem.conditions[location]
-            pts = solver.problem.input_pts[location]
-            # send points to correct device
-            pts = pts.to(device=device, dtype=precision)
-            pts = pts.requires_grad_(True)
-            pts.retain_grad()
-            # PINN loss: equation evaluated only for sampling locations
-            target = condition.equation.residual(pts, solver.forward(pts))
-            res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
-            tot_loss.append(torch.abs(target))
-
-        print(tot_loss)
-
-        return torch.vstack(tot_loss), res_loss
-
-    def _r3_routine(self, trainer):
-        """
-        R3 refinement main routine.
-
-        :param Trainer trainer: PINA Trainer.
-        """
-        # compute residual (all device possible)
-        tot_loss, res_loss = self._compute_residual(trainer)
-        tot_loss = tot_loss.as_subclass(torch.Tensor)
-
-        # !!!!!! From now everything is performed on CPU !!!!!!
-
-        # average loss
-        avg = (tot_loss.mean()).to("cpu")
-        old_pts = {}  # points to be retained
-        for location in self._sampling_locations:
-            pts = trainer._model.problem.input_pts[location]
-            labels = pts.labels
-            pts = pts.cpu().detach().as_subclass(torch.Tensor)
-            residuals = res_loss[location].cpu()
-            mask = (residuals > avg).flatten()
-            if any(mask):  # append residuals greater than average
-                pts = (pts[mask]).as_subclass(LabelTensor)
-                pts.labels = labels
-                old_pts[location] = pts
-                numb_pts = self._const_pts[location] - len(old_pts[location])
-                # sample new points
-                trainer._model.problem.discretise_domain(
-                    numb_pts, "random", locations=[location]
-                )
-
-            else:  # if no res greater than average, samples all uniformly
-                numb_pts = self._const_pts[location]
-                # sample new points
-                trainer._model.problem.discretise_domain(
-                    numb_pts, "random", locations=[location]
-                )
-        # adding previous population points
-        trainer._model.problem.add_points(old_pts)
-
-        # update dataloader
-        trainer._create_or_update_loader()
-
-    def on_train_start(self, trainer, _):
-        """
-        Callback function called at the start of training.
-
-        This method extracts the locations for sampling from the problem
-        conditions and calculates the total population.
-
-        :param trainer: The trainer object managing the training process.
-        :type trainer: pytorch_lightning.Trainer
-        :param _: Placeholder argument (not used).
-
-        :return: None
-        :rtype: None
-        """
-        # extract locations for sampling
-        problem = trainer.solver.problem
-        locations = []
-        for condition_name in problem.conditions:
-            condition = problem.conditions[condition_name]
-            if hasattr(condition, "location"):
-                locations.append(condition_name)
-        self._sampling_locations = locations
-
-        # extract total population
-        const_pts = {}  # for each location, store the # of pts to keep constant
-        for location in self._sampling_locations:
-            pts = trainer._model.problem.input_pts[location]
-            const_pts[location] = len(pts)
-        self._const_pts = const_pts
-
-    def on_train_epoch_end(self, trainer, __):
-        """
-        Callback function called at the end of each training epoch.
-
-        This method triggers the R3 routine for refinement if the current
-        epoch is a multiple of `_sample_every`.
-
-        :param trainer: The trainer object managing the training process.
-        :type trainer: pytorch_lightning.Trainer
-        :param __: Placeholder argument (not used).
-
-        :return: None
-        :rtype: None
-        """
-        if trainer.current_epoch % self._sample_every == 0:
-            self._r3_routine(trainer)
+        raise NotImplementedError(
+            "R3Refinement callback is being refactored in the pina "
+            f"{importlib.metadata.metadata('pina-mathlab')['Version']} "
+            "version. Please use version 0.1 if R3Refinement is required."
+        )
+
+        # super().__init__()
+
+        # # sample every
+        # check_consistency(sample_every, int)
+        # self._sample_every = sample_every
+        # self._const_pts = None
+
+    # def _compute_residual(self, trainer):
+    #     """
+    #     Computes the residuals for a PINN object.
+
+    #     :return: the total loss, and pointwise loss.
+    #     :rtype: tuple
+    #     """
+
+    #     # extract the solver and device from trainer
+    #     solver = trainer.solver
+    #     device = trainer._accelerator_connector._accelerator_flag
+    #     precision = trainer.precision
+    #     if precision == "64-true":
+    #         precision = torch.float64
+    #     elif precision == "32-true":
+    #         precision = torch.float32
+    #     else:
+    #         raise RuntimeError(
+    #             "Currently R3Refinement is only implemented "
+    #             "for precision '32-true' and '64-true', set "
+    #             "Trainer precision to match one of the "
+    #             "available precisions."
+    #         )
+
+    #     # compute residual
+    #     res_loss = {}
+    #     tot_loss = []
+    #     for location in self._sampling_locations:
+    #         condition = solver.problem.conditions[location]
+    #         pts = solver.problem.input_pts[location]
+    #         # send points to correct device
+    #         pts = pts.to(device=device, dtype=precision)
+    #         pts = pts.requires_grad_(True)
+    #         pts.retain_grad()
+    #         # PINN loss: equation evaluated only for sampling locations
+    #         target = condition.equation.residual(pts, solver.forward(pts))
+    #         res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
+    #         tot_loss.append(torch.abs(target))
+
+    #     print(tot_loss)
+
+    #     return torch.vstack(tot_loss), res_loss
+
+    # def _r3_routine(self, trainer):
+    #     """
+    #     R3 refinement main routine.
+
+    #     :param Trainer trainer: PINA Trainer.
+    #     """
+    #     # compute residual (all device possible)
+    #     tot_loss, res_loss = self._compute_residual(trainer)
+    #     tot_loss = tot_loss.as_subclass(torch.Tensor)
+
+    #     # !!!!!! From now everything is performed on CPU !!!!!!
+
+    #     # average loss
+    #     avg = (tot_loss.mean()).to("cpu")
+    #     old_pts = {}  # points to be retained
+    #     for location in self._sampling_locations:
+    #         pts = trainer._model.problem.input_pts[location]
+    #         labels = pts.labels
+    #         pts = pts.cpu().detach().as_subclass(torch.Tensor)
+    #         residuals = res_loss[location].cpu()
+    #         mask = (residuals > avg).flatten()
+    #         if any(mask):  # append residuals greater than average
+    #             pts = (pts[mask]).as_subclass(LabelTensor)
+    #             pts.labels = labels
+    #             old_pts[location] = pts
+    #             numb_pts = self._const_pts[location] - len(old_pts[location])
+    #             # sample new points
+    #             trainer._model.problem.discretise_domain(
+    #                 numb_pts, "random", locations=[location]
+    #             )
+
+    #         else:  # if no res greater than average, samples all uniformly
+    #             numb_pts = self._const_pts[location]
+    #             # sample new points
+    #             trainer._model.problem.discretise_domain(
+    #                 numb_pts, "random", locations=[location]
+    #             )
+    #     # adding previous population points
+    #     trainer._model.problem.add_points(old_pts)
+
+    #     # update dataloader
+    #     trainer._create_or_update_loader()
+
+    # def on_train_start(self, trainer, _):
+    #     """
+    #     Callback function called at the start of training.
+
+    #     This method extracts the locations for sampling from the problem
+    #     conditions and calculates the total population.
+
+    #     :param trainer: The trainer object managing the training process.
+    #     :type trainer: pytorch_lightning.Trainer
+    #     :param _: Placeholder argument (not used).
+
+    #     :return: None
+    #     :rtype: None
+    #     """
+    #     # extract locations for sampling
+    #     problem = trainer.solver.problem
+    #     locations = []
+    #     for condition_name in problem.conditions:
+    #         condition = problem.conditions[condition_name]
+    #         if hasattr(condition, "location"):
+    #             locations.append(condition_name)
+    #     self._sampling_locations = locations
+
+    #     # extract total population
+    #     const_pts = {}  # for each location, store the pts to keep constant
+    #     for location in self._sampling_locations:
+    #         pts = trainer._model.problem.input_pts[location]
+    #         const_pts[location] = len(pts)
+    #     self._const_pts = const_pts
+
+    # def on_train_epoch_end(self, trainer, __):
+    #     """
+    #     Callback function called at the end of each training epoch.
+
+    #     This method triggers the R3 routine for refinement if the current
+    #     epoch is a multiple of `_sample_every`.
+
+    #     :param trainer: The trainer object managing the training process.
+    #     :type trainer: pytorch_lightning.Trainer
+    #     :param __: Placeholder argument (not used).
+
+    #     :return: None
+    #     :rtype: None
+    #     """
+    #     if trainer.current_epoch % self._sample_every == 0:
+    #         self._r3_routine(trainer)
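With this change, constructing the callback fails fast with a version-aware message instead of running the old routine. A minimal sketch of how calling code might detect this until the refactor lands; the import location of R3Refinement and the fallback behaviour are assumptions, not part of the commit:

import importlib.metadata

from pina.callback import R3Refinement  # assumed import location

callbacks = []
try:
    # Since this commit, instantiation raises until the callback is refactored.
    callbacks.append(R3Refinement(sample_every=5))
except NotImplementedError as err:
    installed = importlib.metadata.metadata("pina-mathlab")["Version"]
    print(f"Skipping R3Refinement on pina-mathlab {installed}: {err}")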
@@ -37,12 +37,13 @@ class LinearWeightUpdate(Callback):
         check_consistency(self.initial_value, (float, int), subclass=False)
         check_consistency(self.target_value, (float, int), subclass=False)

-    def on_train_start(self, trainer, solver):
+    def on_train_start(self, trainer, pl_module):
         """
         Initialize the weight of the condition to the specified `initial_value`.

-        :param Trainer trainer: a pina:class:`Trainer` instance.
-        :param SolverInterface solver: a pina:class:`SolverInterface` instance.
+        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
+        :param SolverInterface pl_module: A
+            :class:`~pina.solver.solver.SolverInterface` instance.
         """
         # Check that the target epoch is valid
         if not 0 < self.target_epoch <= trainer.max_epochs:
@@ -52,7 +53,7 @@ class LinearWeightUpdate(Callback):
             )

         # Check that the condition is a problem condition
-        if self.condition_name not in solver.problem.conditions:
+        if self.condition_name not in pl_module.problem.conditions:
             raise ValueError(
                 f"`{self.condition_name}` must be a problem condition."
             )
@@ -66,20 +67,21 @@ class LinearWeightUpdate(Callback):
             )

         # Check that the weighting schema is ScalarWeighting
-        if not isinstance(solver.weighting, ScalarWeighting):
+        if not isinstance(pl_module.weighting, ScalarWeighting):
             raise ValueError("The weighting schema must be ScalarWeighting.")

         # Initialize the weight of the condition
-        solver.weighting.weights[self.condition_name] = self.initial_value
+        pl_module.weighting.weights[self.condition_name] = self.initial_value

-    def on_train_epoch_start(self, trainer, solver):
+    def on_train_epoch_start(self, trainer, pl_module):
         """
         Adjust at each epoch the weight of the condition.

-        :param Trainer trainer: a pina:class:`Trainer` instance.
-        :param SolverInterface solver: a pina:class:`SolverInterface` instance.
+        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
+        :param SolverInterface pl_module: A
+            :class:`~pina.solver.solver.SolverInterface` instance.
         """
         if 0 < trainer.current_epoch <= self.target_epoch:
-            solver.weighting.weights[self.condition_name] += (
+            pl_module.weighting.weights[self.condition_name] += (
                 self.target_value - self.initial_value
             ) / (self.target_epoch - 1)
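The update rule above raises the weight of `condition_name` by a fixed increment at the start of every epoch up to `target_epoch`. A standalone sketch of the resulting schedule, mirroring the arithmetic in `on_train_start` and `on_train_epoch_start`; this helper is illustrative only and not part of the callback:

def linear_weight_schedule(initial_value, target_value, target_epoch, max_epochs):
    # Weight starts at `initial_value` (set in on_train_start) and is then
    # incremented by (target_value - initial_value) / (target_epoch - 1) at the
    # start of every epoch with 0 < current_epoch <= target_epoch.
    weight = initial_value
    history = [weight]
    for current_epoch in range(1, max_epochs):
        if 0 < current_epoch <= target_epoch:
            weight += (target_value - initial_value) / (target_epoch - 1)
        history.append(weight)
    return history


print(linear_weight_schedule(0.0, 1.0, target_epoch=5, max_epochs=8))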
@@ -1,27 +1,27 @@
 """PINA Callbacks Implementations"""

 from lightning.pytorch.callbacks import Callback
-import torch
+from ..optim import TorchOptimizer
 from ..utils import check_consistency
-from pina.optim import TorchOptimizer


 class SwitchOptimizer(Callback):
-
-    def __init__(self, new_optimizers, epoch_switch):
-        """
-        PINA Implementation of a Lightning Callback to switch optimizer during
-        training.
+    """
+    PINA Implementation of a Lightning Callback to switch optimizer during
+    training.
+    """

-        This callback allows for switching between different optimizers during
+    def __init__(self, new_optimizers, epoch_switch):
+        """
+        This callback allows switching between different optimizers during
         training, enabling the exploration of multiple optimization strategies
-        without the need to stop training.
+        without interrupting the training process.

         :param new_optimizers: The model optimizers to switch to. Can be a
-            single :class:`torch.optim.Optimizer` or a list of them for multiple
-            model solver.
+            single :class:`torch.optim.Optimizer` instance or a list of them
+            for multiple model solver.
         :type new_optimizers: pina.optim.TorchOptimizer | list
-        :param epoch_switch: The epoch at which to switch to the new optimizer.
+        :param epoch_switch: The epoch at which the optimizer switch occurs.
         :type epoch_switch: int

         Example:
@@ -46,7 +46,7 @@ class SwitchOptimizer(Callback):

     def on_train_epoch_start(self, trainer, __):
         """
-        Callback function to switch optimizer at the start of each training epoch.
+        Switch the optimizer at the start of the specified training epoch.

         :param trainer: The trainer object managing the training process.
         :type trainer: pytorch_lightning.Trainer
@@ -59,7 +59,7 @@ class SwitchOptimizer(Callback):
         optims = []

         for idx, optim in enumerate(self._new_optimizers):
-            optim.hook(trainer.solver.models[idx].parameters())
-            optims.append(optim.instance)
+            optim.hook(trainer.solver._pina_models[idx].parameters())
+            optims.append(optim)

-        trainer.optimizers = optims
+        trainer.solver._pina_optimizers = optims
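The new `_pina_models` / `_pina_optimizers` wiring is exercised by the updated test further down; a condensed usage sketch is shown here. The import paths for `Trainer`, `PINN`, `FeedForward` and `SwitchOptimizer` are assumptions, the rest mirrors the test modules in this commit:

import torch

from pina.trainer import Trainer                              # assumed paths
from pina.solver import PINN
from pina.model import FeedForward
from pina.optim import TorchOptimizer
from pina.callback.optimizer_callback import SwitchOptimizer  # assumed module path
from pina.problem.zoo import Poisson2DSquareProblem as Poisson

# Build a small Poisson problem and solver, as in the tests below.
problem = Poisson()
problem.discretise_domain(10, "grid", domains=["g1", "g2", "g3", "g4"])
problem.discretise_domain(10, "grid", domains="D")
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
solver = PINN(problem=problem, model=model)

# Start with the solver's default optimizer and hand over to LBFGS at epoch 3.
switch_cb = SwitchOptimizer(TorchOptimizer(torch.optim.LBFGS, lr=0.001), epoch_switch=3)
trainer = Trainer(solver=solver, callbacks=[switch_cb], accelerator="cpu", max_epochs=5)
trainer.train()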
@@ -1,7 +1,7 @@
 """PINA Callbacks Implementations"""

-import torch
 import copy
+import torch

 from lightning.pytorch.callbacks import Callback, TQDMProgressBar
 from lightning.pytorch.callbacks.progress.progress_bar import (
@@ -11,22 +11,37 @@ from pina.utils import check_consistency


 class MetricTracker(Callback):
+    """
+    Lightning Callback for Metric Tracking.
+    """

     def __init__(self, metrics_to_track=None):
         """
-        Lightning Callback for Metric Tracking.
-
-        Tracks specific metrics during the training process.
-
-        :ivar _collection: A list to store collected metrics after each epoch.
-
-        :param metrics_to_track: List of metrics to track. Defaults to train/val loss.
-        :type metrics_to_track: list, optional
+        Tracks specified metrics during training.
+
+        :param metrics_to_track: List of metrics to track.
+            Defaults to train loss.
+        :type metrics_to_track: list[str], optional
         """
         super().__init__()
         self._collection = []
-        # Default to tracking 'train_loss' and 'val_loss' if not specified
-        self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"]
+        # Default to tracking 'train_loss' if not specified
+        self.metrics_to_track = metrics_to_track
+
+    def setup(self, trainer, pl_module, stage):
+        """
+        Called when fit, validate, test, predict, or tune begins.
+
+        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
+        :param SolverInterface pl_module: A
+            :class:`~pina.solver.solver.SolverInterface` instance.
+        :param str stage: Either 'fit', 'test' or 'predict'.
+        """
+        if self.metrics_to_track is None and trainer.batch_size is None:
+            self.metrics_to_track = ["train_loss"]
+        elif self.metrics_to_track is None:
+            self.metrics_to_track = ["train_loss_epoch"]
+        return super().setup(trainer, pl_module, stage)

     def on_train_epoch_end(self, trainer, pl_module):
         """
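The new `setup` hook resolves the default metric name from `trainer.batch_size`, which is what the updated test below relies on. A usage sketch, assuming the solver construction from the test modules and that `MetricTracker` lives next to `PINAProgressBar` in `pina.callback.processing_callback`:

from pina.callback.processing_callback import MetricTracker  # assumed location

tracker = MetricTracker()  # no explicit metrics: default resolved in setup()
trainer = Trainer(
    solver=solver,          # e.g. the PINN solver built in the test modules
    callbacks=[tracker],
    accelerator="cpu",
    max_epochs=5,
    log_every_n_steps=1,
)
trainer.train()
# Without batching, setup() falls back to "train_loss"; with a batch size it
# tracks "train_loss_epoch" instead.
print(sorted(tracker.metrics.keys()))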
@@ -71,26 +86,28 @@ class MetricTracker(Callback):


 class PINAProgressBar(TQDMProgressBar):
+    """
+    PINA Implementation of a Lightning Callback for enriching the progress bar.
+    """

-    BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
+    BAR_FORMAT = (
+        "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, "
+        "{rate_noinv_fmt}{postfix}]"
+    )

     def __init__(self, metrics="val", **kwargs):
         """
-        PINA Implementation of a Lightning Callback for enriching the progress
-        bar.
-
-        This class provides functionality to display only relevant metrics
-        during the training process.
-
-        :param metrics: Logged metrics to display during the training. It should
-            be a subset of the conditions keys defined in
+        This class enables the display of only relevant metrics during training.
+
+        :param metrics: Logged metrics to be shown during the training.
+            Must be a subset of the conditions keys defined in
             :obj:`pina.condition.Condition`.
         :type metrics: str | list(str) | tuple(str)

         :Keyword Arguments:
-            The additional keyword arguments specify the progress bar
-            and can be choosen from the `pytorch-lightning
-            TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_
+            The additional keyword arguments specify the progress bar and can be
+            choosen from the `pytorch-lightning TQDMProgressBar API
+            <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_

         Example:
             >>> pbar = PINAProgressBar(['mean'])
@@ -105,9 +122,9 @@ class PINAProgressBar(TQDMProgressBar):
         self._sorted_metrics = metrics

     def get_metrics(self, trainer, pl_module):
-        r"""Combines progress bar metrics collected from the trainer with
+        r"""Combine progress bar metrics collected from the trainer with
         standard metrics from get_standard_metrics.
-        Implement this to override the items displayed in the progress bar.
+        Override this method to customize the items shown in the progress bar.
         The progress bar metrics are sorted according to ``metrics``.

         Here is an example of how to override the defaults:
@@ -122,20 +139,20 @@ class PINAProgressBar(TQDMProgressBar):

         :return: Dictionary with the items to be displayed in the progress bar.
         :rtype: tuple(dict)

         """
         standard_metrics = get_standard_metrics(trainer)
         pbar_metrics = trainer.progress_bar_metrics
         if pbar_metrics:
             pbar_metrics = {
-                key: pbar_metrics[key] for key in self._sorted_metrics
+                key: pbar_metrics[key]
+                for key in pbar_metrics
+                if key in self._sorted_metrics
             }
         return {**standard_metrics, **pbar_metrics}

-    def on_fit_start(self, trainer, pl_module):
+    def setup(self, trainer, pl_module, stage):
         """
-        Check that the metrics defined in the initialization are available,
-        i.e. are correctly logged.
+        Check that the initialized metrics are available and correctly logged.

         :param trainer: The trainer object managing the training process.
         :type trainer: pytorch_lightning.Trainer
@@ -150,7 +167,11 @@ class PINAProgressBar(TQDMProgressBar):
         ):
             raise KeyError(f"Key '{key}' is not present in the dictionary")
         # add the loss pedix
+        if trainer.batch_size is not None:
+            pedix = "_loss_epoch"
+        else:
+            pedix = "_loss"
         self._sorted_metrics = [
-            metric + "_loss" for metric in self._sorted_metrics
+            metric + pedix for metric in self._sorted_metrics
         ]
-        return super().on_fit_start(trainer, pl_module)
+        return super().setup(trainer, pl_module, stage)
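The suffix handling added in `setup` means the keys looked up in `trainer.progress_bar_metrics` depend on whether batching is used. A small standalone sketch of that resolution; the callback does this internally, so the helper below is illustrative only:

def resolve_progress_bar_keys(metrics, batch_size):
    # Mirrors the "pedix" logic in PINAProgressBar.setup().
    pedix = "_loss_epoch" if batch_size is not None else "_loss"
    return [metric + pedix for metric in metrics]


print(resolve_progress_bar_keys(["val", "mean"], batch_size=None))  # val_loss, mean_loss
print(resolve_progress_bar_keys(["val", "mean"], batch_size=8))     # val_loss_epoch, mean_loss_epoch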
@@ -64,8 +64,7 @@ class Trainer(lightning.pytorch.Trainer):
         :Keyword Arguments:
             The additional keyword arguments specify the training setup
             and can be choosen from the `pytorch-lightning
-            Trainer API <https://lightning.ai/docs/pytorch/stable/common/
-            trainer.html#trainer-class-api>`_
+            Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
         """
         # check consistency for init types
         self._check_input_consistency(
@@ -96,7 +95,6 @@ class Trainer(lightning.pytorch.Trainer):

         # Setting default kwargs, overriding lightning defaults
         kwargs.setdefault("enable_progress_bar", True)
-        kwargs.setdefault("logger", None)

         super().__init__(**kwargs)

@@ -127,9 +125,6 @@ class Trainer(lightning.pytorch.Trainer):

         # logging
         self.logging_kwargs = {
-            "logger": bool(
-                kwargs["logger"] is not None or kwargs["logger"] is True
-            ),
             "sync_dist": bool(
                 len(self._accelerator_connector._parallel_devices) > 1
             ),
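With the `logger` default removed, logger handling follows plain lightning behaviour: whatever is passed through the keyword arguments is forwarded unchanged to `lightning.pytorch.Trainer`. A hedged example of passing an explicit lightning logger; the `CSVLogger` choice is illustrative, and `solver` is assumed to be built as in the test modules below:

from lightning.pytorch.loggers import CSVLogger

trainer = Trainer(
    solver=solver,
    accelerator="cpu",
    max_epochs=5,
    logger=CSVLogger("logs"),  # forwarded directly to the lightning Trainer
)
trainer.train()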
@@ -23,17 +23,18 @@ def test_metric_tracker_constructor():
     MetricTracker()


-# def test_metric_tracker_routine(): #TODO revert
-#     # make the trainer
-#     trainer = Trainer(solver=solver,
-#                       callback=[
-#                           MetricTracker()
-#                           ],
-#                       accelerator='cpu',
-#                       max_epochs=5)
-#     trainer.train()
-#     # get the tracked metrics
-#     metrics = trainer.callback[0].metrics
-#     # assert the logged metrics are correct
-#     logged_metrics = sorted(list(metrics.keys()))
-#     assert logged_metrics == ['train_loss_epoch', 'train_loss_step', 'val_loss']
+def test_metric_tracker_routine():
+    # make the trainer
+    trainer = Trainer(
+        solver=solver,
+        callbacks=[MetricTracker()],
+        accelerator="cpu",
+        max_epochs=5,
+        log_every_n_steps=1,
+    )
+    trainer.train()
+    # get the tracked metrics
+    metrics = trainer.callbacks[0].metrics
+    # assert the logged metrics are correct
+    logged_metrics = sorted(list(metrics.keys()))
+    assert logged_metrics == ["train_loss"]
@@ -21,19 +21,25 @@ model = FeedForward(
 # make the solver
 solver = PINN(problem=poisson_problem, model=model)

-adam_optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
-lbfgs_optimizer = TorchOptimizer(torch.optim.LBFGS, lr=0.001)
+adam = TorchOptimizer(torch.optim.Adam, lr=0.01)
+lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=0.001)


 def test_switch_optimizer_constructor():
-    SwitchOptimizer(adam_optimizer, epoch_switch=10)
+    SwitchOptimizer(adam, epoch_switch=10)


-# def test_switch_optimizer_routine(): #TODO revert
-#     # make the trainer
-#     switch_opt_callback = SwitchOptimizer(lbfgs_optimizer, epoch_switch=3)
-#     trainer = Trainer(solver=solver,
-#                       callback=[switch_opt_callback],
-#                       accelerator='cpu',
-#                       max_epochs=5)
-#     trainer.train()
+def test_switch_optimizer_routine():
+    # check initial optimizer
+    solver.configure_optimizers()
+    assert solver.optimizer.instance.__class__ == torch.optim.Adam
+    # make the trainer
+    switch_opt_callback = SwitchOptimizer(lbfgs, epoch_switch=3)
+    trainer = Trainer(
+        solver=solver,
+        callbacks=[switch_opt_callback],
+        accelerator="cpu",
+        max_epochs=5,
+    )
+    trainer.train()
+    assert solver.optimizer.instance.__class__ == torch.optim.LBFGS
@@ -5,29 +5,32 @@ from pina.callback.processing_callback import PINAProgressBar
 from pina.problem.zoo import Poisson2DSquareProblem as Poisson


-# # make the problem
-# poisson_problem = Poisson()
-# boundaries = ['nil_g1', 'nil_g2', 'nil_g3', 'nil_g4']
-# n = 10
-# poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
-# poisson_problem.discretise_domain(n, 'grid', locations='laplace_D')
-# model = FeedForward(len(poisson_problem.input_variables),
-#                     len(poisson_problem.output_variables))
+# make the problem
+poisson_problem = Poisson()
+boundaries = ["g1", "g2", "g3", "g4"]
+n = 10
+condition_names = list(poisson_problem.conditions.keys())
+poisson_problem.discretise_domain(n, "grid", domains=boundaries)
+poisson_problem.discretise_domain(n, "grid", domains="D")
+model = FeedForward(
+    len(poisson_problem.input_variables), len(poisson_problem.output_variables)
+)

-# # make the solver
-# solver = PINN(problem=poisson_problem, model=model)
+# make the solver
+solver = PINN(problem=poisson_problem, model=model)


-# def test_progress_bar_constructor():
-#     PINAProgressBar(['mean'])
+def test_progress_bar_constructor():
+    PINAProgressBar()


-# def test_progress_bar_routine():
-#     # make the trainer
-#     trainer = Trainer(solver=solver,
-#                       callback=[
-#                           PINAProgressBar(['mean', 'laplace_D'])
-#                           ],
-#                       accelerator='cpu',
-#                       max_epochs=5)
-#     trainer.train()
-#     # TODO there should be a check that the correct metrics are displayed
+def test_progress_bar_routine():
+    # make the trainer
+    trainer = Trainer(
+        solver=solver,
+        callbacks=[PINAProgressBar(["val", condition_names[0]])],
+        accelerator="cpu",
+        max_epochs=5,
+    )
+    trainer.train()
+    # TODO there should be a check that the correct metrics are displayed
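The TODO left in `test_progress_bar_routine` could be addressed with the `get_metrics` hook shown earlier in this commit. A possible sketch, not part of the commit, reusing the module-level `solver` and `condition_names` defined above:

def test_progress_bar_displayed_metrics():
    # Possible follow-up for the TODO: check which metrics the bar shows.
    pbar = PINAProgressBar(["val", condition_names[0]])
    trainer = Trainer(
        solver=solver,
        callbacks=[pbar],
        accelerator="cpu",
        max_epochs=1,
    )
    trainer.train()
    displayed = pbar.get_metrics(trainer, solver)
    assert any(key.startswith(condition_names[0]) for key in displayed)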