Update callbacks and tests (#482)

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Committed by: Nicola Demo
Parent: 6ae301622b
Commit: 632934f9cc
@@ -1,5 +1,6 @@
"""PINA Callbacks Implementations"""

import importlib.metadata
import torch
from lightning.pytorch.callbacks import Callback
from ..label_tensor import LabelTensor
@@ -7,17 +8,17 @@ from ..utils import check_consistency


class R3Refinement(Callback):
    """
    PINA Implementation of an R3 Refinement Callback.
    """

    def __init__(self, sample_every):
        """
        PINA Implementation of an R3 Refinement Callback.

        This callback implements the R3 (Retain-Resample-Release) routine for
        sampling new points based on adaptive search.
        The algorithm incrementally accumulates collocation points in regions
        of high PDE residuals, and releases those
        with low residuals. Points are sampled uniformly in all regions
        where sampling is needed.
        of high PDE residuals, and releases those with low residuals.
        Points are sampled uniformly in all regions where sampling is needed.

        .. seealso::
@@ -33,142 +34,148 @@ class R3Refinement(Callback):
        Example:
            >>> r3_callback = R3Refinement(sample_every=5)
        """
        super().__init__()

        # sample every
        check_consistency(sample_every, int)
        self._sample_every = sample_every
        self._const_pts = None

    def _compute_residual(self, trainer):
        """
        Computes the residuals for a PINN object.

        :return: the total loss, and pointwise loss.
        :rtype: tuple
        """

        # extract the solver and device from trainer
        solver = trainer.solver
        device = trainer._accelerator_connector._accelerator_flag
        precision = trainer.precision
        if precision == "64-true":
            precision = torch.float64
        elif precision == "32-true":
            precision = torch.float32
        else:
            raise RuntimeError(
                "Currently R3Refinement is only implemented "
                "for precision '32-true' and '64-true', set "
                "Trainer precision to match one of the "
                "available precisions."
        raise NotImplementedError(
            "R3Refinement callback is being refactored in the pina "
            f"{importlib.metadata.metadata('pina-mathlab')['Version']} "
            "version. Please use version 0.1 if R3Refinement is required."
        )

        # compute residual
        res_loss = {}
        tot_loss = []
        for location in self._sampling_locations:  # TODO fix for new collector
            condition = solver.problem.conditions[location]
            pts = solver.problem.input_pts[location]
            # send points to correct device
            pts = pts.to(device=device, dtype=precision)
            pts = pts.requires_grad_(True)
            pts.retain_grad()
            # PINN loss: equation evaluated only for sampling locations
            target = condition.equation.residual(pts, solver.forward(pts))
            res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
            tot_loss.append(torch.abs(target))
        # super().__init__()

        print(tot_loss)
        # # sample every
        # check_consistency(sample_every, int)
        # self._sample_every = sample_every
        # self._const_pts = None

        return torch.vstack(tot_loss), res_loss
    # def _compute_residual(self, trainer):
    #     """
    #     Computes the residuals for a PINN object.
    def _r3_routine(self, trainer):
        """
        R3 refinement main routine.
    #     :return: the total loss, and pointwise loss.
    #     :rtype: tuple
    #     """

        :param Trainer trainer: PINA Trainer.
        """
        # compute residual (all device possible)
        tot_loss, res_loss = self._compute_residual(trainer)
        tot_loss = tot_loss.as_subclass(torch.Tensor)
    #     # extract the solver and device from trainer
    #     solver = trainer.solver
    #     device = trainer._accelerator_connector._accelerator_flag
    #     precision = trainer.precision
    #     if precision == "64-true":
    #         precision = torch.float64
    #     elif precision == "32-true":
    #         precision = torch.float32
    #     else:
    #         raise RuntimeError(
    #             "Currently R3Refinement is only implemented "
    #             "for precision '32-true' and '64-true', set "
    #             "Trainer precision to match one of the "
    #             "available precisions."
    #         )

        # !!!!!! From now everything is performed on CPU !!!!!!
    #     # compute residual
    #     res_loss = {}
    #     tot_loss = []
    #     for location in self._sampling_locations:
    #         condition = solver.problem.conditions[location]
    #         pts = solver.problem.input_pts[location]
    #         # send points to correct device
    #         pts = pts.to(device=device, dtype=precision)
    #         pts = pts.requires_grad_(True)
    #         pts.retain_grad()
    #         # PINN loss: equation evaluated only for sampling locations
    #         target = condition.equation.residual(pts, solver.forward(pts))
    #         res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
    #         tot_loss.append(torch.abs(target))

        # average loss
        avg = (tot_loss.mean()).to("cpu")
        old_pts = {}  # points to be retained
        for location in self._sampling_locations:
            pts = trainer._model.problem.input_pts[location]
            labels = pts.labels
            pts = pts.cpu().detach().as_subclass(torch.Tensor)
            residuals = res_loss[location].cpu()
            mask = (residuals > avg).flatten()
            if any(mask):  # append residuals greater than average
                pts = (pts[mask]).as_subclass(LabelTensor)
                pts.labels = labels
                old_pts[location] = pts
                numb_pts = self._const_pts[location] - len(old_pts[location])
                # sample new points
                trainer._model.problem.discretise_domain(
                    numb_pts, "random", locations=[location]
                )
        # print(tot_loss)

            else:  # if no res greater than average, samples all uniformly
                numb_pts = self._const_pts[location]
                # sample new points
                trainer._model.problem.discretise_domain(
                    numb_pts, "random", locations=[location]
                )
        # adding previous population points
        trainer._model.problem.add_points(old_pts)
        # return torch.vstack(tot_loss), res_loss

        # update dataloader
        trainer._create_or_update_loader()
    # def _r3_routine(self, trainer):
    #     """
    #     R3 refinement main routine.
    def on_train_start(self, trainer, _):
        """
        Callback function called at the start of training.
    #     :param Trainer trainer: PINA Trainer.
    #     """
    #     # compute residual (all device possible)
    #     tot_loss, res_loss = self._compute_residual(trainer)
    #     tot_loss = tot_loss.as_subclass(torch.Tensor)

        This method extracts the locations for sampling from the problem
        conditions and calculates the total population.
    #     # !!!!!! From now everything is performed on CPU !!!!!!

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param _: Placeholder argument (not used).
    #     # average loss
    #     avg = (tot_loss.mean()).to("cpu")
    #     old_pts = {}  # points to be retained
    #     for location in self._sampling_locations:
    #         pts = trainer._model.problem.input_pts[location]
    #         labels = pts.labels
    #         pts = pts.cpu().detach().as_subclass(torch.Tensor)
    #         residuals = res_loss[location].cpu()
    #         mask = (residuals > avg).flatten()
    #         if any(mask):  # append residuals greater than average
    #             pts = (pts[mask]).as_subclass(LabelTensor)
    #             pts.labels = labels
    #             old_pts[location] = pts
    #             numb_pts = self._const_pts[location] - len(old_pts[location])
    #             # sample new points
    #             trainer._model.problem.discretise_domain(
    #                 numb_pts, "random", locations=[location]
    #             )

        :return: None
        :rtype: None
        """
        # extract locations for sampling
        problem = trainer.solver.problem
        locations = []
        for condition_name in problem.conditions:
            condition = problem.conditions[condition_name]
            if hasattr(condition, "location"):
                locations.append(condition_name)
        self._sampling_locations = locations
    #         else:  # if no res greater than average, samples all uniformly
    #             numb_pts = self._const_pts[location]
    #             # sample new points
    #             trainer._model.problem.discretise_domain(
    #                 numb_pts, "random", locations=[location]
    #             )
    #     # adding previous population points
    #     trainer._model.problem.add_points(old_pts)

        # extract total population
        const_pts = {}  # for each location, store the # of pts to keep constant
        for location in self._sampling_locations:
            pts = trainer._model.problem.input_pts[location]
            const_pts[location] = len(pts)
        self._const_pts = const_pts
    #     # update dataloader
    #     trainer._create_or_update_loader()

    def on_train_epoch_end(self, trainer, __):
        """
        Callback function called at the end of each training epoch.
    # def on_train_start(self, trainer, _):
    #     """
    #     Callback function called at the start of training.

        This method triggers the R3 routine for refinement if the current
        epoch is a multiple of `_sample_every`.
    #     This method extracts the locations for sampling from the problem
    #     conditions and calculates the total population.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
        :param __: Placeholder argument (not used).
    #     :param trainer: The trainer object managing the training process.
    #     :type trainer: pytorch_lightning.Trainer
    #     :param _: Placeholder argument (not used).

        :return: None
        :rtype: None
        """
        if trainer.current_epoch % self._sample_every == 0:
            self._r3_routine(trainer)
    #     :return: None
    #     :rtype: None
    #     """
    #     # extract locations for sampling
    #     problem = trainer.solver.problem
    #     locations = []
    #     for condition_name in problem.conditions:
    #         condition = problem.conditions[condition_name]
    #         if hasattr(condition, "location"):
    #             locations.append(condition_name)
    #     self._sampling_locations = locations

    #     # extract total population
    #     const_pts = {}  # for each location, store the pts to keep constant
    #     for location in self._sampling_locations:
    #         pts = trainer._model.problem.input_pts[location]
    #         const_pts[location] = len(pts)
    #     self._const_pts = const_pts

    # def on_train_epoch_end(self, trainer, __):
    #     """
    #     Callback function called at the end of each training epoch.

    #     This method triggers the R3 routine for refinement if the current
    #     epoch is a multiple of `_sample_every`.

    #     :param trainer: The trainer object managing the training process.
    #     :type trainer: pytorch_lightning.Trainer
    #     :param __: Placeholder argument (not used).

    #     :return: None
    #     :rtype: None
    #     """
    #     if trainer.current_epoch % self._sample_every == 0:
    #         self._r3_routine(trainer)
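For readers who want to see the Retain-Resample-Release idea in isolation, the snippet below is a minimal, self-contained sketch of the retain/release step using plain torch tensors. The function name, the uniform sampling in [0, 1]^d, and the tensor shapes are illustrative assumptions and not part of the PINA API:

import torch


def r3_retain_release(points, residuals):
    # Retain points whose absolute residual exceeds the mean residual,
    # release the others and refill the budget with uniform samples.
    residuals = residuals.abs().flatten()
    mask = residuals > residuals.mean()
    retained = points[mask]
    n_new = points.shape[0] - retained.shape[0]
    resampled = torch.rand(n_new, points.shape[1])  # hypothetical domain [0, 1]^d
    return torch.cat([retained, resampled], dim=0)


# usage on dummy data: the population size stays constant, as in the callback
pts = torch.rand(100, 2)
new_pts = r3_retain_release(pts, torch.randn(100))
assert new_pts.shape == pts.shape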
@@ -37,12 +37,13 @@ class LinearWeightUpdate(Callback):
        check_consistency(self.initial_value, (float, int), subclass=False)
        check_consistency(self.target_value, (float, int), subclass=False)

    def on_train_start(self, trainer, solver):
    def on_train_start(self, trainer, pl_module):
        """
        Initialize the weight of the condition to the specified `initial_value`.

        :param Trainer trainer: a pina:class:`Trainer` instance.
        :param SolverInterface solver: a pina:class:`SolverInterface` instance.
        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
        :param SolverInterface pl_module: A
            :class:`~pina.solver.solver.SolverInterface` instance.
        """
        # Check that the target epoch is valid
        if not 0 < self.target_epoch <= trainer.max_epochs:
@@ -52,7 +53,7 @@ class LinearWeightUpdate(Callback):
            )

        # Check that the condition is a problem condition
        if self.condition_name not in solver.problem.conditions:
        if self.condition_name not in pl_module.problem.conditions:
            raise ValueError(
                f"`{self.condition_name}` must be a problem condition."
            )
@@ -66,20 +67,21 @@ class LinearWeightUpdate(Callback):
            )

        # Check that the weighting schema is ScalarWeighting
        if not isinstance(solver.weighting, ScalarWeighting):
        if not isinstance(pl_module.weighting, ScalarWeighting):
            raise ValueError("The weighting schema must be ScalarWeighting.")

        # Initialize the weight of the condition
        solver.weighting.weights[self.condition_name] = self.initial_value
        pl_module.weighting.weights[self.condition_name] = self.initial_value

    def on_train_epoch_start(self, trainer, solver):
    def on_train_epoch_start(self, trainer, pl_module):
        """
        Adjust at each epoch the weight of the condition.

        :param Trainer trainer: a pina:class:`Trainer` instance.
        :param SolverInterface solver: a pina:class:`SolverInterface` instance.
        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
        :param SolverInterface pl_module: A
            :class:`~pina.solver.solver.SolverInterface` instance.
        """
        if 0 < trainer.current_epoch <= self.target_epoch:
            solver.weighting.weights[self.condition_name] += (
            pl_module.weighting.weights[self.condition_name] += (
                self.target_value - self.initial_value
            ) / (self.target_epoch - 1)
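The arithmetic of the LinearWeightUpdate callback above can be traced with a small stand-alone sketch; linear_weight_schedule is a hypothetical helper that simply replays the per-epoch update rule shown in the diff (it assumes target_epoch > 1 to avoid a zero division):

def linear_weight_schedule(initial_value, target_value, target_epoch, max_epochs):
    # Weight used at each epoch when the increment
    # (target_value - initial_value) / (target_epoch - 1)
    # is applied whenever 0 < epoch <= target_epoch.
    weight = float(initial_value)
    schedule = []
    for epoch in range(max_epochs):
        if 0 < epoch <= target_epoch:
            weight += (target_value - initial_value) / (target_epoch - 1)
        schedule.append(weight)
    return schedule


# e.g. ramp a condition weight from 1 to 10 during the first epochs
print(linear_weight_schedule(1.0, 10.0, target_epoch=10, max_epochs=12))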
@@ -1,27 +1,27 @@
"""PINA Callbacks Implementations"""

from lightning.pytorch.callbacks import Callback
import torch
from ..optim import TorchOptimizer
from ..utils import check_consistency
from pina.optim import TorchOptimizer


class SwitchOptimizer(Callback):

    def __init__(self, new_optimizers, epoch_switch):
    """
    PINA Implementation of a Lightning Callback to switch optimizer during
    training.
    """

        This callback allows for switching between different optimizers during
    def __init__(self, new_optimizers, epoch_switch):
        """
        This callback allows switching between different optimizers during
        training, enabling the exploration of multiple optimization strategies
        without the need to stop training.
        without interrupting the training process.

        :param new_optimizers: The model optimizers to switch to. Can be a
            single :class:`torch.optim.Optimizer` or a list of them for multiple
            model solvers.
            single :class:`torch.optim.Optimizer` instance or a list of them
            for multiple model solvers.
        :type new_optimizers: pina.optim.TorchOptimizer | list
        :param epoch_switch: The epoch at which to switch to the new optimizer.
        :param epoch_switch: The epoch at which the optimizer switch occurs.
        :type epoch_switch: int

        Example:
@@ -46,7 +46,7 @@ class SwitchOptimizer(Callback):

    def on_train_epoch_start(self, trainer, __):
        """
        Callback function to switch optimizer at the start of each training epoch.
        Switch the optimizer at the start of the specified training epoch.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
@@ -59,7 +59,7 @@ class SwitchOptimizer(Callback):
        optims = []

        for idx, optim in enumerate(self._new_optimizers):
            optim.hook(trainer.solver.models[idx].parameters())
            optims.append(optim.instance)
            optim.hook(trainer.solver._pina_models[idx].parameters())
            optims.append(optim)

        trainer.optimizers = optims
        trainer.solver._pina_optimizers = optims
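Outside of PINA, the effect of SwitchOptimizer can be reproduced with a few lines of plain PyTorch: rebind a fresh optimizer to the same parameters once the switch epoch is reached. The toy model, learning rates, and epoch_switch value below are illustrative only:

import torch

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epoch_switch = 3  # hypothetical switch epoch

for epoch in range(5):
    if epoch == epoch_switch:
        # swap in a new optimizer bound to the same parameters,
        # mirroring what the callback does on the solver's models
        optimizer = torch.optim.LBFGS(model.parameters(), lr=0.001)
    # ... run the training step for this epoch with `optimizer` ...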
@@ -1,7 +1,7 @@
"""PINA Callbacks Implementations"""

import torch
import copy
import torch

from lightning.pytorch.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import (
@@ -11,22 +11,37 @@ from pina.utils import check_consistency


class MetricTracker(Callback):
    """
    Lightning Callback for Metric Tracking.
    """

    def __init__(self, metrics_to_track=None):
        """
        Lightning Callback for Metric Tracking.
        Tracks specified metrics during training.

        Tracks specific metrics during the training process.

        :ivar _collection: A list to store collected metrics after each epoch.

        :param metrics_to_track: List of metrics to track. Defaults to train/val loss.
        :type metrics_to_track: list, optional
        :param metrics_to_track: List of metrics to track.
            Defaults to train loss.
        :type metrics_to_track: list[str], optional
        """
        super().__init__()
        self._collection = []
        # Default to tracking 'train_loss' and 'val_loss' if not specified
        self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"]
        # Default to tracking 'train_loss' if not specified
        self.metrics_to_track = metrics_to_track

    def setup(self, trainer, pl_module, stage):
        """
        Called when fit, validate, test, predict, or tune begins.

        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
        :param SolverInterface pl_module: A
            :class:`~pina.solver.solver.SolverInterface` instance.
        :param str stage: Either 'fit', 'test' or 'predict'.
        """
        if self.metrics_to_track is None and trainer.batch_size is None:
            self.metrics_to_track = ["train_loss"]
        elif self.metrics_to_track is None:
            self.metrics_to_track = ["train_loss_epoch"]
        return super().setup(trainer, pl_module, stage)

    def on_train_epoch_end(self, trainer, pl_module):
        """
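A minimal Lightning callback in the same spirit as MetricTracker is sketched below; it relies only on trainer.logged_metrics, which Lightning exposes on the trainer, and the class name is illustrative rather than part of PINA:

from lightning.pytorch.callbacks import Callback


class MinimalTracker(Callback):
    """Collect a copy of the logged metrics at the end of every epoch."""

    def __init__(self):
        super().__init__()
        self._collection = []

    def on_train_epoch_end(self, trainer, pl_module):
        # trainer.logged_metrics holds what has been logged via self.log so far
        self._collection.append(dict(trainer.logged_metrics))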
@@ -71,26 +86,28 @@ class MetricTracker(Callback):


class PINAProgressBar(TQDMProgressBar):
    """
    PINA Implementation of a Lightning Callback for enriching the progress bar.
    """

    BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
    BAR_FORMAT = (
        "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, "
        "{rate_noinv_fmt}{postfix}]"
    )

    def __init__(self, metrics="val", **kwargs):
        """
        PINA Implementation of a Lightning Callback for enriching the progress
        bar.
        This class enables the display of only relevant metrics during training.

        This class provides functionality to display only relevant metrics
        during the training process.

        :param metrics: Logged metrics to display during the training. It should
            be a subset of the conditions keys defined in
        :param metrics: Logged metrics to be shown during the training.
            Must be a subset of the conditions keys defined in
            :obj:`pina.condition.Condition`.
        :type metrics: str | list(str) | tuple(str)

        :Keyword Arguments:
            The additional keyword arguments specify the progress bar
            and can be chosen from the `pytorch-lightning
            TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_
            The additional keyword arguments specify the progress bar and can be
            chosen from the `pytorch-lightning TQDMProgressBar API
            <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_

        Example:
            >>> pbar = PINAProgressBar(['mean'])
@@ -105,9 +122,9 @@ class PINAProgressBar(TQDMProgressBar):
        self._sorted_metrics = metrics

    def get_metrics(self, trainer, pl_module):
        r"""Combines progress bar metrics collected from the trainer with
        r"""Combine progress bar metrics collected from the trainer with
        standard metrics from get_standard_metrics.
        Implement this to override the items displayed in the progress bar.
        Override this method to customize the items shown in the progress bar.
        The progress bar metrics are sorted according to ``metrics``.

        Here is an example of how to override the defaults:
@@ -122,20 +139,20 @@ class PINAProgressBar(TQDMProgressBar):

        :return: Dictionary with the items to be displayed in the progress bar.
        :rtype: tuple(dict)

        """
        standard_metrics = get_standard_metrics(trainer)
        pbar_metrics = trainer.progress_bar_metrics
        if pbar_metrics:
            pbar_metrics = {
                key: pbar_metrics[key] for key in self._sorted_metrics
                key: pbar_metrics[key]
                for key in pbar_metrics
                if key in self._sorted_metrics
            }
        return {**standard_metrics, **pbar_metrics}

    def on_fit_start(self, trainer, pl_module):
    def setup(self, trainer, pl_module, stage):
        """
        Check that the metrics defined in the initialization are available,
        i.e. are correctly logged.
        Check that the initialized metrics are available and correctly logged.

        :param trainer: The trainer object managing the training process.
        :type trainer: pytorch_lightning.Trainer
@@ -150,7 +167,11 @@ class PINAProgressBar(TQDMProgressBar):
        ):
            raise KeyError(f"Key '{key}' is not present in the dictionary")
        # add the loss pedix
        if trainer.batch_size is not None:
            pedix = "_loss_epoch"
        else:
            pedix = "_loss"
        self._sorted_metrics = [
            metric + "_loss" for metric in self._sorted_metrics
            metric + pedix for metric in self._sorted_metrics
        ]
        return super().on_fit_start(trainer, pl_module)
        return super().setup(trainer, pl_module, stage)
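The key change to get_metrics above is that the displayed metrics are now intersected with the requested keys instead of indexed directly, which could raise a KeyError when a requested key had not been logged yet. A stand-alone sketch of that filtering step, with hypothetical names:

def filter_progress_metrics(pbar_metrics, wanted_keys):
    # keep only the logged metrics that were requested at construction time
    return {key: value for key, value in pbar_metrics.items() if key in wanted_keys}


print(filter_progress_metrics(
    {"train_loss": 0.12, "val_loss": 0.09, "D_loss": 0.03},
    wanted_keys={"val_loss", "D_loss"},
))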
@@ -64,8 +64,7 @@ class Trainer(lightning.pytorch.Trainer):
        :Keyword Arguments:
            The additional keyword arguments specify the training setup
            and can be chosen from the `pytorch-lightning
            Trainer API <https://lightning.ai/docs/pytorch/stable/common/
            trainer.html#trainer-class-api>`_
            Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
        """
        # check consistency for init types
        self._check_input_consistency(
@@ -96,7 +95,6 @@ class Trainer(lightning.pytorch.Trainer):

        # Setting default kwargs, overriding lightning defaults
        kwargs.setdefault("enable_progress_bar", True)
        kwargs.setdefault("logger", None)

        super().__init__(**kwargs)

@@ -127,9 +125,6 @@ class Trainer(lightning.pytorch.Trainer):

        # logging
        self.logging_kwargs = {
            "logger": bool(
                kwargs["logger"] is not None or kwargs["logger"] is True
            ),
            "sync_dist": bool(
                len(self._accelerator_connector._parallel_devices) > 1
            ),
@@ -23,17 +23,18 @@ def test_metric_tracker_constructor():
    MetricTracker()


# def test_metric_tracker_routine(): #TODO revert
#     # make the trainer
#     trainer = Trainer(solver=solver,
#                       callback=[MetricTracker()],
#                       accelerator='cpu',
#                       max_epochs=5)
#     trainer.train()
#     # get the tracked metrics
#     metrics = trainer.callback[0].metrics
#     # assert the logged metrics are correct
#     logged_metrics = sorted(list(metrics.keys()))
#     assert logged_metrics == ['train_loss_epoch', 'train_loss_step', 'val_loss']
def test_metric_tracker_routine():
    # make the trainer
    trainer = Trainer(
        solver=solver,
        callbacks=[MetricTracker()],
        accelerator="cpu",
        max_epochs=5,
        log_every_n_steps=1,
    )
    trainer.train()
    # get the tracked metrics
    metrics = trainer.callbacks[0].metrics
    # assert the logged metrics are correct
    logged_metrics = sorted(list(metrics.keys()))
    assert logged_metrics == ["train_loss"]
@@ -21,19 +21,25 @@ model = FeedForward(
# make the solver
solver = PINN(problem=poisson_problem, model=model)

adam_optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
lbfgs_optimizer = TorchOptimizer(torch.optim.LBFGS, lr=0.001)
adam = TorchOptimizer(torch.optim.Adam, lr=0.01)
lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=0.001)


def test_switch_optimizer_constructor():
    SwitchOptimizer(adam_optimizer, epoch_switch=10)
    SwitchOptimizer(adam, epoch_switch=10)


# def test_switch_optimizer_routine(): #TODO revert
#     # make the trainer
#     switch_opt_callback = SwitchOptimizer(lbfgs_optimizer, epoch_switch=3)
#     trainer = Trainer(solver=solver,
#                       callback=[switch_opt_callback],
#                       accelerator='cpu',
#                       max_epochs=5)
#     trainer.train()
def test_switch_optimizer_routine():
    # check initial optimizer
    solver.configure_optimizers()
    assert solver.optimizer.instance.__class__ == torch.optim.Adam
    # make the trainer
    switch_opt_callback = SwitchOptimizer(lbfgs, epoch_switch=3)
    trainer = Trainer(
        solver=solver,
        callbacks=[switch_opt_callback],
        accelerator="cpu",
        max_epochs=5,
    )
    trainer.train()
    assert solver.optimizer.instance.__class__ == torch.optim.LBFGS
@@ -5,29 +5,32 @@ from pina.callback.processing_callback import PINAProgressBar
from pina.problem.zoo import Poisson2DSquareProblem as Poisson


# # make the problem
# poisson_problem = Poisson()
# boundaries = ['nil_g1', 'nil_g2', 'nil_g3', 'nil_g4']
# n = 10
# poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
# poisson_problem.discretise_domain(n, 'grid', locations='laplace_D')
# model = FeedForward(len(poisson_problem.input_variables),
#                     len(poisson_problem.output_variables))
# make the problem
poisson_problem = Poisson()
boundaries = ["g1", "g2", "g3", "g4"]
n = 10
condition_names = list(poisson_problem.conditions.keys())
poisson_problem.discretise_domain(n, "grid", domains=boundaries)
poisson_problem.discretise_domain(n, "grid", domains="D")
model = FeedForward(
    len(poisson_problem.input_variables), len(poisson_problem.output_variables)
)

# # make the solver
# solver = PINN(problem=poisson_problem, model=model)
# make the solver
solver = PINN(problem=poisson_problem, model=model)


# def test_progress_bar_constructor():
#     PINAProgressBar(['mean'])
def test_progress_bar_constructor():
    PINAProgressBar()

# def test_progress_bar_routine():
#     # make the trainer
#     trainer = Trainer(solver=solver,
#                       callback=[PINAProgressBar(['mean', 'laplace_D'])],
#                       accelerator='cpu',
#                       max_epochs=5)
#     trainer.train()
#     # TODO there should be a check that the correct metrics are displayed

def test_progress_bar_routine():
    # make the trainer
    trainer = Trainer(
        solver=solver,
        callbacks=[PINAProgressBar(["val", condition_names[0]])],
        accelerator="cpu",
        max_epochs=5,
    )
    trainer.train()
    # TODO there should be a check that the correct metrics are displayed