Update callbacks and tests (#482)

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
This commit is contained in:
Dario Coscia
2025-03-13 16:19:38 +01:00
committed by Nicola Demo
parent 6ae301622b
commit 632934f9cc
8 changed files with 264 additions and 229 deletions

View File

@@ -1,5 +1,6 @@
"""PINA Callbacks Implementations""" """PINA Callbacks Implementations"""
import importlib.metadata
import torch import torch
from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks import Callback
from ..label_tensor import LabelTensor from ..label_tensor import LabelTensor
@@ -7,17 +8,17 @@ from ..utils import check_consistency
class R3Refinement(Callback): class R3Refinement(Callback):
"""
PINA Implementation of an R3 Refinement Callback.
"""
def __init__(self, sample_every): def __init__(self, sample_every):
""" """
PINA Implementation of an R3 Refinement Callback.
This callback implements the R3 (Retain-Resample-Release) routine for This callback implements the R3 (Retain-Resample-Release) routine for
sampling new points based on adaptive search. sampling new points based on adaptive search.
The algorithm incrementally accumulates collocation points in regions The algorithm incrementally accumulates collocation points in regions
of high PDE residuals, and releases those of high PDE residuals, and releases those with low residuals.
with low residuals. Points are sampled uniformly in all regions Points are sampled uniformly in all regions where sampling is needed.
where sampling is needed.
.. seealso:: .. seealso::
@@ -33,142 +34,148 @@ class R3Refinement(Callback):
Example: Example:
>>> r3_callback = R3Refinement(sample_every=5) >>> r3_callback = R3Refinement(sample_every=5)
""" """
super().__init__() raise NotImplementedError(
"R3Refinement callback is being refactored in the pina "
# sample every f"{importlib.metadata.metadata('pina-mathlab')['Version']} "
check_consistency(sample_every, int) "version. Please use version 0.1 if R3Refinement is required."
self._sample_every = sample_every
self._const_pts = None
def _compute_residual(self, trainer):
"""
Computes the residuals for a PINN object.
:return: the total loss, and pointwise loss.
:rtype: tuple
"""
# extract the solver and device from trainer
solver = trainer.solver
device = trainer._accelerator_connector._accelerator_flag
precision = trainer.precision
if precision == "64-true":
precision = torch.float64
elif precision == "32-true":
precision = torch.float32
else:
raise RuntimeError(
"Currently R3Refinement is only implemented "
"for precision '32-true' and '64-true', set "
"Trainer precision to match one of the "
"available precisions."
) )
# compute residual # super().__init__()
res_loss = {}
tot_loss = []
for location in self._sampling_locations: # TODO fix for new collector
condition = solver.problem.conditions[location]
pts = solver.problem.input_pts[location]
# send points to correct device
pts = pts.to(device=device, dtype=precision)
pts = pts.requires_grad_(True)
pts.retain_grad()
# PINN loss: equation evaluated only for sampling locations
target = condition.equation.residual(pts, solver.forward(pts))
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target))
print(tot_loss) # # sample every
# check_consistency(sample_every, int)
# self._sample_every = sample_every
# self._const_pts = None
return torch.vstack(tot_loss), res_loss # def _compute_residual(self, trainer):
# """
# Computes the residuals for a PINN object.
def _r3_routine(self, trainer): # :return: the total loss, and pointwise loss.
""" # :rtype: tuple
R3 refinement main routine. # """
:param Trainer trainer: PINA Trainer. # # extract the solver and device from trainer
""" # solver = trainer.solver
# compute residual (all device possible) # device = trainer._accelerator_connector._accelerator_flag
tot_loss, res_loss = self._compute_residual(trainer) # precision = trainer.precision
tot_loss = tot_loss.as_subclass(torch.Tensor) # if precision == "64-true":
# precision = torch.float64
# elif precision == "32-true":
# precision = torch.float32
# else:
# raise RuntimeError(
# "Currently R3Refinement is only implemented "
# "for precision '32-true' and '64-true', set "
# "Trainer precision to match one of the "
# "available precisions."
# )
# !!!!!! From now everything is performed on CPU !!!!!! # # compute residual
# res_loss = {}
# tot_loss = []
# for location in self._sampling_locations:
# condition = solver.problem.conditions[location]
# pts = solver.problem.input_pts[location]
# # send points to correct device
# pts = pts.to(device=device, dtype=precision)
# pts = pts.requires_grad_(True)
# pts.retain_grad()
# # PINN loss: equation evaluated only for sampling locations
# target = condition.equation.residual(pts, solver.forward(pts))
# res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
# tot_loss.append(torch.abs(target))
# average loss # print(tot_loss)
avg = (tot_loss.mean()).to("cpu")
old_pts = {} # points to be retained
for location in self._sampling_locations:
pts = trainer._model.problem.input_pts[location]
labels = pts.labels
pts = pts.cpu().detach().as_subclass(torch.Tensor)
residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten()
if any(mask): # append residuals greater than average
pts = (pts[mask]).as_subclass(LabelTensor)
pts.labels = labels
old_pts[location] = pts
numb_pts = self._const_pts[location] - len(old_pts[location])
# sample new points
trainer._model.problem.discretise_domain(
numb_pts, "random", locations=[location]
)
else: # if no res greater than average, samples all uniformly # return torch.vstack(tot_loss), res_loss
numb_pts = self._const_pts[location]
# sample new points
trainer._model.problem.discretise_domain(
numb_pts, "random", locations=[location]
)
# adding previous population points
trainer._model.problem.add_points(old_pts)
# update dataloader # def _r3_routine(self, trainer):
trainer._create_or_update_loader() # """
# R3 refinement main routine.
def on_train_start(self, trainer, _): # :param Trainer trainer: PINA Trainer.
""" # """
Callback function called at the start of training. # # compute residual (all device possible)
# tot_loss, res_loss = self._compute_residual(trainer)
# tot_loss = tot_loss.as_subclass(torch.Tensor)
This method extracts the locations for sampling from the problem # # !!!!!! From now everything is performed on CPU !!!!!!
conditions and calculates the total population.
:param trainer: The trainer object managing the training process. # # average loss
:type trainer: pytorch_lightning.Trainer # avg = (tot_loss.mean()).to("cpu")
:param _: Placeholder argument (not used). # old_pts = {} # points to be retained
# for location in self._sampling_locations:
# pts = trainer._model.problem.input_pts[location]
# labels = pts.labels
# pts = pts.cpu().detach().as_subclass(torch.Tensor)
# residuals = res_loss[location].cpu()
# mask = (residuals > avg).flatten()
# if any(mask): # append residuals greater than average
# pts = (pts[mask]).as_subclass(LabelTensor)
# pts.labels = labels
# old_pts[location] = pts
# numb_pts = self._const_pts[location] - len(old_pts[location])
# # sample new points
# trainer._model.problem.discretise_domain(
# numb_pts, "random", locations=[location]
# )
:return: None # else: # if no res greater than average, samples all uniformly
:rtype: None # numb_pts = self._const_pts[location]
""" # # sample new points
# extract locations for sampling # trainer._model.problem.discretise_domain(
problem = trainer.solver.problem # numb_pts, "random", locations=[location]
locations = [] # )
for condition_name in problem.conditions: # # adding previous population points
condition = problem.conditions[condition_name] # trainer._model.problem.add_points(old_pts)
if hasattr(condition, "location"):
locations.append(condition_name)
self._sampling_locations = locations
# extract total population # # update dataloader
const_pts = {} # for each location, store the # of pts to keep constant # trainer._create_or_update_loader()
for location in self._sampling_locations:
pts = trainer._model.problem.input_pts[location]
const_pts[location] = len(pts)
self._const_pts = const_pts
def on_train_epoch_end(self, trainer, __): # def on_train_start(self, trainer, _):
""" # """
Callback function called at the end of each training epoch. # Callback function called at the start of training.
This method triggers the R3 routine for refinement if the current # This method extracts the locations for sampling from the problem
epoch is a multiple of `_sample_every`. # conditions and calculates the total population.
:param trainer: The trainer object managing the training process. # :param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer # :type trainer: pytorch_lightning.Trainer
:param __: Placeholder argument (not used). # :param _: Placeholder argument (not used).
:return: None # :return: None
:rtype: None # :rtype: None
""" # """
if trainer.current_epoch % self._sample_every == 0: # # extract locations for sampling
self._r3_routine(trainer) # problem = trainer.solver.problem
# locations = []
# for condition_name in problem.conditions:
# condition = problem.conditions[condition_name]
# if hasattr(condition, "location"):
# locations.append(condition_name)
# self._sampling_locations = locations
# # extract total population
# const_pts = {} # for each location, store the pts to keep constant
# for location in self._sampling_locations:
# pts = trainer._model.problem.input_pts[location]
# const_pts[location] = len(pts)
# self._const_pts = const_pts
# def on_train_epoch_end(self, trainer, __):
# """
# Callback function called at the end of each training epoch.
# This method triggers the R3 routine for refinement if the current
# epoch is a multiple of `_sample_every`.
# :param trainer: The trainer object managing the training process.
# :type trainer: pytorch_lightning.Trainer
# :param __: Placeholder argument (not used).
# :return: None
# :rtype: None
# """
# if trainer.current_epoch % self._sample_every == 0:
# self._r3_routine(trainer)

View File

@@ -37,12 +37,13 @@ class LinearWeightUpdate(Callback):
check_consistency(self.initial_value, (float, int), subclass=False) check_consistency(self.initial_value, (float, int), subclass=False)
check_consistency(self.target_value, (float, int), subclass=False) check_consistency(self.target_value, (float, int), subclass=False)
def on_train_start(self, trainer, solver): def on_train_start(self, trainer, pl_module):
""" """
Initialize the weight of the condition to the specified `initial_value`. Initialize the weight of the condition to the specified `initial_value`.
:param Trainer trainer: a pina:class:`Trainer` instance. :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
:param SolverInterface solver: a pina:class:`SolverInterface` instance. :param SolverInterface pl_module: A
:class:`~pina.solver.solver.SolverInterface` instance.
""" """
# Check that the target epoch is valid # Check that the target epoch is valid
if not 0 < self.target_epoch <= trainer.max_epochs: if not 0 < self.target_epoch <= trainer.max_epochs:
@@ -52,7 +53,7 @@ class LinearWeightUpdate(Callback):
) )
# Check that the condition is a problem condition # Check that the condition is a problem condition
if self.condition_name not in solver.problem.conditions: if self.condition_name not in pl_module.problem.conditions:
raise ValueError( raise ValueError(
f"`{self.condition_name}` must be a problem condition." f"`{self.condition_name}` must be a problem condition."
) )
@@ -66,20 +67,21 @@ class LinearWeightUpdate(Callback):
) )
# Check that the weighting schema is ScalarWeighting # Check that the weighting schema is ScalarWeighting
if not isinstance(solver.weighting, ScalarWeighting): if not isinstance(pl_module.weighting, ScalarWeighting):
raise ValueError("The weighting schema must be ScalarWeighting.") raise ValueError("The weighting schema must be ScalarWeighting.")
# Initialize the weight of the condition # Initialize the weight of the condition
solver.weighting.weights[self.condition_name] = self.initial_value pl_module.weighting.weights[self.condition_name] = self.initial_value
def on_train_epoch_start(self, trainer, solver): def on_train_epoch_start(self, trainer, pl_module):
""" """
Adjust at each epoch the weight of the condition. Adjust at each epoch the weight of the condition.
:param Trainer trainer: a pina:class:`Trainer` instance. :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
:param SolverInterface solver: a pina:class:`SolverInterface` instance. :param SolverInterface pl_module: A
:class:`~pina.solver.solver.SolverInterface` instance.
""" """
if 0 < trainer.current_epoch <= self.target_epoch: if 0 < trainer.current_epoch <= self.target_epoch:
solver.weighting.weights[self.condition_name] += ( pl_module.weighting.weights[self.condition_name] += (
self.target_value - self.initial_value self.target_value - self.initial_value
) / (self.target_epoch - 1) ) / (self.target_epoch - 1)

View File

@@ -1,27 +1,27 @@
"""PINA Callbacks Implementations""" """PINA Callbacks Implementations"""
from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks import Callback
import torch from ..optim import TorchOptimizer
from ..utils import check_consistency from ..utils import check_consistency
from pina.optim import TorchOptimizer
class SwitchOptimizer(Callback): class SwitchOptimizer(Callback):
def __init__(self, new_optimizers, epoch_switch):
""" """
PINA Implementation of a Lightning Callback to switch optimizer during PINA Implementation of a Lightning Callback to switch optimizer during
training. training.
"""
This callback allows for switching between different optimizers during def __init__(self, new_optimizers, epoch_switch):
"""
This callback allows switching between different optimizers during
training, enabling the exploration of multiple optimization strategies training, enabling the exploration of multiple optimization strategies
without the need to stop training. without interrupting the training process.
:param new_optimizers: The model optimizers to switch to. Can be a :param new_optimizers: The model optimizers to switch to. Can be a
single :class:`torch.optim.Optimizer` or a list of them for multiple single :class:`torch.optim.Optimizer` instance or a list of them
model solver. for multiple model solver.
:type new_optimizers: pina.optim.TorchOptimizer | list :type new_optimizers: pina.optim.TorchOptimizer | list
:param epoch_switch: The epoch at which to switch to the new optimizer. :param epoch_switch: The epoch at which the optimizer switch occurs.
:type epoch_switch: int :type epoch_switch: int
Example: Example:
@@ -46,7 +46,7 @@ class SwitchOptimizer(Callback):
def on_train_epoch_start(self, trainer, __): def on_train_epoch_start(self, trainer, __):
""" """
Callback function to switch optimizer at the start of each training epoch. Switch the optimizer at the start of the specified training epoch.
:param trainer: The trainer object managing the training process. :param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer :type trainer: pytorch_lightning.Trainer
@@ -59,7 +59,7 @@ class SwitchOptimizer(Callback):
optims = [] optims = []
for idx, optim in enumerate(self._new_optimizers): for idx, optim in enumerate(self._new_optimizers):
optim.hook(trainer.solver.models[idx].parameters()) optim.hook(trainer.solver._pina_models[idx].parameters())
optims.append(optim.instance) optims.append(optim)
trainer.optimizers = optims trainer.solver._pina_optimizers = optims

View File

@@ -1,7 +1,7 @@
"""PINA Callbacks Implementations""" """PINA Callbacks Implementations"""
import torch
import copy import copy
import torch
from lightning.pytorch.callbacks import Callback, TQDMProgressBar from lightning.pytorch.callbacks import Callback, TQDMProgressBar
from lightning.pytorch.callbacks.progress.progress_bar import ( from lightning.pytorch.callbacks.progress.progress_bar import (
@@ -11,22 +11,37 @@ from pina.utils import check_consistency
class MetricTracker(Callback): class MetricTracker(Callback):
"""
Lightning Callback for Metric Tracking.
"""
def __init__(self, metrics_to_track=None): def __init__(self, metrics_to_track=None):
""" """
Lightning Callback for Metric Tracking. Tracks specified metrics during training.
Tracks specific metrics during the training process. :param metrics_to_track: List of metrics to track.
Defaults to train loss.
:ivar _collection: A list to store collected metrics after each epoch. :type metrics_to_track: list[str], optional
:param metrics_to_track: List of metrics to track. Defaults to train/val loss.
:type metrics_to_track: list, optional
""" """
super().__init__() super().__init__()
self._collection = [] self._collection = []
# Default to tracking 'train_loss' and 'val_loss' if not specified # Default to tracking 'train_loss' if not specified
self.metrics_to_track = metrics_to_track or ["train_loss", "val_loss"] self.metrics_to_track = metrics_to_track
def setup(self, trainer, pl_module, stage):
"""
Called when fit, validate, test, predict, or tune begins.
:param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
:param SolverInterface pl_module: A
:class:`~pina.solver.solver.SolverInterface` instance.
:param str stage: Either 'fit', 'test' or 'predict'.
"""
if self.metrics_to_track is None and trainer.batch_size is None:
self.metrics_to_track = ["train_loss"]
elif self.metrics_to_track is None:
self.metrics_to_track = ["train_loss_epoch"]
return super().setup(trainer, pl_module, stage)
def on_train_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
""" """
@@ -71,26 +86,28 @@ class MetricTracker(Callback):
class PINAProgressBar(TQDMProgressBar): class PINAProgressBar(TQDMProgressBar):
"""
PINA Implementation of a Lightning Callback for enriching the progress bar.
"""
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]" BAR_FORMAT = (
"{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, "
"{rate_noinv_fmt}{postfix}]"
)
def __init__(self, metrics="val", **kwargs): def __init__(self, metrics="val", **kwargs):
""" """
PINA Implementation of a Lightning Callback for enriching the progress This class enables the display of only relevant metrics during training.
bar.
This class provides functionality to display only relevant metrics :param metrics: Logged metrics to be shown during the training.
during the training process. Must be a subset of the conditions keys defined in
:param metrics: Logged metrics to display during the training. It should
be a subset of the conditions keys defined in
:obj:`pina.condition.Condition`. :obj:`pina.condition.Condition`.
:type metrics: str | list(str) | tuple(str) :type metrics: str | list(str) | tuple(str)
:Keyword Arguments: :Keyword Arguments:
The additional keyword arguments specify the progress bar The additional keyword arguments specify the progress bar and can be
and can be choosen from the `pytorch-lightning choosen from the `pytorch-lightning TQDMProgressBar API
TQDMProgressBar API <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_ <https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/callback/progress/tqdm_progress.html#TQDMProgressBar>`_
Example: Example:
>>> pbar = PINAProgressBar(['mean']) >>> pbar = PINAProgressBar(['mean'])
@@ -105,9 +122,9 @@ class PINAProgressBar(TQDMProgressBar):
self._sorted_metrics = metrics self._sorted_metrics = metrics
def get_metrics(self, trainer, pl_module): def get_metrics(self, trainer, pl_module):
r"""Combines progress bar metrics collected from the trainer with r"""Combine progress bar metrics collected from the trainer with
standard metrics from get_standard_metrics. standard metrics from get_standard_metrics.
Implement this to override the items displayed in the progress bar. Override this method to customize the items shown in the progress bar.
The progress bar metrics are sorted according to ``metrics``. The progress bar metrics are sorted according to ``metrics``.
Here is an example of how to override the defaults: Here is an example of how to override the defaults:
@@ -122,20 +139,20 @@ class PINAProgressBar(TQDMProgressBar):
:return: Dictionary with the items to be displayed in the progress bar. :return: Dictionary with the items to be displayed in the progress bar.
:rtype: tuple(dict) :rtype: tuple(dict)
""" """
standard_metrics = get_standard_metrics(trainer) standard_metrics = get_standard_metrics(trainer)
pbar_metrics = trainer.progress_bar_metrics pbar_metrics = trainer.progress_bar_metrics
if pbar_metrics: if pbar_metrics:
pbar_metrics = { pbar_metrics = {
key: pbar_metrics[key] for key in self._sorted_metrics key: pbar_metrics[key]
for key in pbar_metrics
if key in self._sorted_metrics
} }
return {**standard_metrics, **pbar_metrics} return {**standard_metrics, **pbar_metrics}
def on_fit_start(self, trainer, pl_module): def setup(self, trainer, pl_module, stage):
""" """
Check that the metrics defined in the initialization are available, Check that the initialized metrics are available and correctly logged.
i.e. are correctly logged.
:param trainer: The trainer object managing the training process. :param trainer: The trainer object managing the training process.
:type trainer: pytorch_lightning.Trainer :type trainer: pytorch_lightning.Trainer
@@ -150,7 +167,11 @@ class PINAProgressBar(TQDMProgressBar):
): ):
raise KeyError(f"Key '{key}' is not present in the dictionary") raise KeyError(f"Key '{key}' is not present in the dictionary")
# add the loss pedix # add the loss pedix
if trainer.batch_size is not None:
pedix = "_loss_epoch"
else:
pedix = "_loss"
self._sorted_metrics = [ self._sorted_metrics = [
metric + "_loss" for metric in self._sorted_metrics metric + pedix for metric in self._sorted_metrics
] ]
return super().on_fit_start(trainer, pl_module) return super().setup(trainer, pl_module, stage)

View File

@@ -64,8 +64,7 @@ class Trainer(lightning.pytorch.Trainer):
:Keyword Arguments: :Keyword Arguments:
The additional keyword arguments specify the training setup The additional keyword arguments specify the training setup
and can be choosen from the `pytorch-lightning and can be choosen from the `pytorch-lightning
Trainer API <https://lightning.ai/docs/pytorch/stable/common/ Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
trainer.html#trainer-class-api>`_
""" """
# check consistency for init types # check consistency for init types
self._check_input_consistency( self._check_input_consistency(
@@ -96,7 +95,6 @@ class Trainer(lightning.pytorch.Trainer):
# Setting default kwargs, overriding lightning defaults # Setting default kwargs, overriding lightning defaults
kwargs.setdefault("enable_progress_bar", True) kwargs.setdefault("enable_progress_bar", True)
kwargs.setdefault("logger", None)
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -127,9 +125,6 @@ class Trainer(lightning.pytorch.Trainer):
# logging # logging
self.logging_kwargs = { self.logging_kwargs = {
"logger": bool(
kwargs["logger"] is not None or kwargs["logger"] is True
),
"sync_dist": bool( "sync_dist": bool(
len(self._accelerator_connector._parallel_devices) > 1 len(self._accelerator_connector._parallel_devices) > 1
), ),

View File

@@ -23,17 +23,18 @@ def test_metric_tracker_constructor():
MetricTracker() MetricTracker()
# def test_metric_tracker_routine(): #TODO revert def test_metric_tracker_routine():
# # make the trainer # make the trainer
# trainer = Trainer(solver=solver, trainer = Trainer(
# callback=[ solver=solver,
# MetricTracker() callbacks=[MetricTracker()],
# ], accelerator="cpu",
# accelerator='cpu', max_epochs=5,
# max_epochs=5) log_every_n_steps=1,
# trainer.train() )
# # get the tracked metrics trainer.train()
# metrics = trainer.callback[0].metrics # get the tracked metrics
# # assert the logged metrics are correct metrics = trainer.callbacks[0].metrics
# logged_metrics = sorted(list(metrics.keys())) # assert the logged metrics are correct
# assert logged_metrics == ['train_loss_epoch', 'train_loss_step', 'val_loss'] logged_metrics = sorted(list(metrics.keys()))
assert logged_metrics == ["train_loss"]

View File

@@ -21,19 +21,25 @@ model = FeedForward(
# make the solver # make the solver
solver = PINN(problem=poisson_problem, model=model) solver = PINN(problem=poisson_problem, model=model)
adam_optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01) adam = TorchOptimizer(torch.optim.Adam, lr=0.01)
lbfgs_optimizer = TorchOptimizer(torch.optim.LBFGS, lr=0.001) lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=0.001)
def test_switch_optimizer_constructor(): def test_switch_optimizer_constructor():
SwitchOptimizer(adam_optimizer, epoch_switch=10) SwitchOptimizer(adam, epoch_switch=10)
# def test_switch_optimizer_routine(): #TODO revert def test_switch_optimizer_routine():
# # make the trainer # check initial optimizer
# switch_opt_callback = SwitchOptimizer(lbfgs_optimizer, epoch_switch=3) solver.configure_optimizers()
# trainer = Trainer(solver=solver, assert solver.optimizer.instance.__class__ == torch.optim.Adam
# callback=[switch_opt_callback], # make the trainer
# accelerator='cpu', switch_opt_callback = SwitchOptimizer(lbfgs, epoch_switch=3)
# max_epochs=5) trainer = Trainer(
# trainer.train() solver=solver,
callbacks=[switch_opt_callback],
accelerator="cpu",
max_epochs=5,
)
trainer.train()
assert solver.optimizer.instance.__class__ == torch.optim.LBFGS

View File

@@ -5,29 +5,32 @@ from pina.callback.processing_callback import PINAProgressBar
from pina.problem.zoo import Poisson2DSquareProblem as Poisson from pina.problem.zoo import Poisson2DSquareProblem as Poisson
# # make the problem # make the problem
# poisson_problem = Poisson() poisson_problem = Poisson()
# boundaries = ['nil_g1', 'nil_g2', 'nil_g3', 'nil_g4'] boundaries = ["g1", "g2", "g3", "g4"]
# n = 10 n = 10
# poisson_problem.discretise_domain(n, 'grid', locations=boundaries) condition_names = list(poisson_problem.conditions.keys())
# poisson_problem.discretise_domain(n, 'grid', locations='laplace_D') poisson_problem.discretise_domain(n, "grid", domains=boundaries)
# model = FeedForward(len(poisson_problem.input_variables), poisson_problem.discretise_domain(n, "grid", domains="D")
# len(poisson_problem.output_variables)) model = FeedForward(
len(poisson_problem.input_variables), len(poisson_problem.output_variables)
)
# # make the solver # make the solver
# solver = PINN(problem=poisson_problem, model=model) solver = PINN(problem=poisson_problem, model=model)
# def test_progress_bar_constructor(): def test_progress_bar_constructor():
# PINAProgressBar(['mean']) PINAProgressBar()
# def test_progress_bar_routine():
# # make the trainer def test_progress_bar_routine():
# trainer = Trainer(solver=solver, # make the trainer
# callback=[ trainer = Trainer(
# PINAProgressBar(['mean', 'laplace_D']) solver=solver,
# ], callbacks=[PINAProgressBar(["val", condition_names[0]])],
# accelerator='cpu', accelerator="cpu",
# max_epochs=5) max_epochs=5,
# trainer.train() )
# # TODO there should be a check that the correct metrics are displayed trainer.train()
# TODO there should be a check that the correct metrics are displayed