PINN variants addition and Solvers Update (#263)
* gpinn/basepinn new classes, pinn restructure
* codacy fix gpinn/basepinn/pinn
* inverse problem fix
* Causal PINN (#267)
* fix GPU training in inverse problem (#283)
* Create a `compute_residual` attribute for `PINNInterface`
* Modify dataloading in solvers (#286)
* Modify PINNInterface by removing _loss_phys, _loss_data
* Adding in PINNInterface a variable to track the current condition during training
* Modify GPINN, PINN, CausalPINN to match changes in PINNInterface
* Competitive Pinn Addition (#288)
* fixing after rebase / fix loss
* fixing final issues
* Modify min max formulation to max min for paper consistency
* Adding SAPINN solver (#291)
* rom solver
* fix import

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Anna Ivagnes <75523024+annaivagnes@users.noreply.github.com>
Co-authored-by: valc89 <103250118+valc89@users.noreply.github.com>
Co-authored-by: Monthly Tag bot <mtbot@noreply.github.com>
Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
@@ -35,8 +35,14 @@ Solvers
    :titlesonly:

    SolverInterface <solvers/solver_interface.rst>
+   PINNInterface <solvers/basepinn.rst>
    PINN <solvers/pinn.rst>
+   GPINN <solvers/gpinn.rst>
+   CausalPINN <solvers/causalpinn.rst>
+   CompetitivePINN <solvers/competitivepinn.rst>
+   SAPINN <solvers/sapinn.rst>
    Supervised solver <solvers/supervised.rst>
+   ReducedOrderModelSolver <solvers/rom.rst>
    GAROM <solvers/garom.rst>
docs/source/_rst/solvers/basepinn.rst (new file)
@@ -0,0 +1,7 @@
PINNInterface
=================
.. currentmodule:: pina.solvers.pinns.basepinn

.. autoclass:: PINNInterface
   :members:
   :show-inheritance:
docs/source/_rst/solvers/causalpinn.rst (new file)
@@ -0,0 +1,7 @@
CausalPINN
==============
.. currentmodule:: pina.solvers.pinns.causalpinn

.. autoclass:: CausalPINN
   :members:
   :show-inheritance:
docs/source/_rst/solvers/competitivepinn.rst (new file)
@@ -0,0 +1,7 @@
CompetitivePINN
=================
.. currentmodule:: pina.solvers.pinns.competitive_pinn

.. autoclass:: CompetitivePINN
   :members:
   :show-inheritance:
docs/source/_rst/solvers/gpinn.rst (new file)
@@ -0,0 +1,7 @@
GPINN
======
.. currentmodule:: pina.solvers.pinns.gpinn

.. autoclass:: GPINN
   :members:
   :show-inheritance:
docs/source/_rst/solvers/pinn.rst
@@ -1,6 +1,6 @@
 PINN
 ======
-.. currentmodule:: pina.solvers.pinn
+.. currentmodule:: pina.solvers.pinns.pinn

 .. autoclass:: PINN
    :members:
docs/source/_rst/solvers/rom.rst (new file)
@@ -0,0 +1,7 @@
ReducedOrderModelSolver
==========================
.. currentmodule:: pina.solvers.rom

.. autoclass:: ReducedOrderModelSolver
   :members:
   :show-inheritance:
docs/source/_rst/solvers/sapinn.rst (new file)
@@ -0,0 +1,7 @@
SAPINN
======
.. currentmodule:: pina.solvers.pinns.sapinn

.. autoclass:: SAPINN
   :members:
   :show-inheritance:
@@ -110,9 +110,9 @@ class AveragingNeuralOperator(KernelNeuralOperator):
         """
         points_tmp = x.extract(self.coordinates_indices)
         new_batch = x.extract(self.field_indices)
-        new_batch = concatenate((new_batch, points_tmp), dim=2)
+        new_batch = concatenate((new_batch, points_tmp), dim=-1)
         new_batch = self._lifting_operator(new_batch)
         new_batch = self._integral_kernels(new_batch)
-        new_batch = concatenate((new_batch, points_tmp), dim=2)
+        new_batch = concatenate((new_batch, points_tmp), dim=-1)
         new_batch = self._projection_operator(new_batch)
         return new_batch
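The switch from `dim=2` to `dim=-1` concatenates along the feature axis regardless of how many leading dimensions the batch carries; for the usual 3-D `(batch, points, channels)` layout the two are identical. A minimal sketch with plain tensors (shapes are illustrative assumptions, not PINA test data):

import torch

fields = torch.rand(8, 100, 2)   # (batch, points, field channels)
coords = torch.rand(8, 100, 3)   # (batch, points, coordinates)

# old and new concatenation agree on 3-D inputs; dim=-1 also stays valid
# if the layout ever gains or loses a leading dimension
assert torch.equal(
    torch.cat((fields, coords), dim=2),
    torch.cat((fields, coords), dim=-1),
)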
pina/solvers/__init__.py
@@ -1,6 +1,19 @@
-__all__ = ["PINN", "GAROM", "SupervisedSolver", "SolverInterface"]
+__all__ = [
+    "SolverInterface",
+    "PINNInterface",
+    "PINN",
+    "GPINN",
+    "CausalPINN",
+    "CompetitivePINN",
+    "SAPINN",
+    "SupervisedSolver",
+    "ReducedOrderModelSolver",
+    "GAROM",
+]

-from .garom import GAROM
-from .pinn import PINN
-from .supervised import SupervisedSolver
 from .solver import SolverInterface
+from .pinns import *
+from .supervised import SupervisedSolver
+from .rom import ReducedOrderModelSolver
+from .garom import GAROM
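With the reorganized `__init__.py`, every solver added in this PR is importable directly from `pina.solvers`. A hedged usage sketch; `my_problem` is a hypothetical user-defined PINA problem (building one is out of scope here), and the keyword defaults mirror the constructors added below:

import torch
from pina.solvers import PINN  # GPINN, CausalPINN, SAPINN, ... import the same way

model = torch.nn.Sequential(   # any torch.nn.Module works as the PINN model
    torch.nn.Linear(1, 20), torch.nn.Tanh(), torch.nn.Linear(20, 1)
)
# `my_problem` stands for a user-defined PINA problem (e.g. a SpatialProblem
# subclass); it is assumed to exist in user code.
solver = PINN(problem=my_problem, model=model,
              optimizer_kwargs={"lr": 0.001})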
pina/solvers/garom.py
@@ -253,18 +253,11 @@ class GAROM(SolverInterface):
         :rtype: LabelTensor
         """

-        dataloader = self.trainer.train_dataloader
         condition_idx = batch["condition"]

         for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

-            if sys.version_info >= (3, 8):
-                condition_name = dataloader.condition_names[condition_id]
-            else:
-                condition_name = dataloader.loaders.condition_names[
-                    condition_id
-                ]
+            condition_name = self._dataloader.condition_names[condition_id]

             condition = self.problem.conditions[condition_name]
             pts = batch["pts"].detach()
             out = batch["output"]
pina/solvers/pinn.py (deleted; superseded by pina/solvers/pinns/pinn.py)
@@ -1,232 +0,0 @@
""" Module for PINN """

import torch

try:
    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
except ImportError:
    from torch.optim.lr_scheduler import (
        _LRScheduler as LRScheduler,
    )  # torch < 2.0

import sys
from torch.optim.lr_scheduler import ConstantLR

from .solver import SolverInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss import LossInterface
from ..problem import InverseProblem
from torch.nn.modules.loss import _Loss

torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732


class PINN(SolverInterface):
    """
    PINN solver class. This class implements Physics Informed Neural
    Network solvers, using a user specified ``model`` to solve a specific
    ``problem``. It can be used for solving both forward and inverse problems.

    .. seealso::

        **Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L.,
        Perdikaris, P., Wang, S., & Yang, L. (2021).
        Physics-informed machine learning. Nature Reviews Physics, 3(6), 422-440.
        <https://doi.org/10.1038/s42254-021-00314-5>`_.
    """

    def __init__(
        self,
        problem,
        model,
        extra_features=None,
        loss=torch.nn.MSELoss(),
        optimizer=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001},
        scheduler=ConstantLR,
        scheduler_kwargs={"factor": 1, "total_iters": 0},
    ):
        """
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.nn.Module model: The neural network model to use.
        :param torch.nn.Module loss: The loss function used as minimizer,
            default :class:`torch.nn.MSELoss`.
        :param torch.nn.Module extra_features: The additional input
            features to use as augmented input.
        :param torch.optim.Optimizer optimizer: The neural network optimizer to
            use; default is :class:`torch.optim.Adam`.
        :param dict optimizer_kwargs: Optimizer constructor keyword args.
        :param torch.optim.LRScheduler scheduler: Learning
            rate scheduler.
        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
        """
        super().__init__(
            models=[model],
            problem=problem,
            optimizers=[optimizer],
            optimizers_kwargs=[optimizer_kwargs],
            extra_features=extra_features,
        )

        # check consistency
        check_consistency(scheduler, LRScheduler, subclass=True)
        check_consistency(scheduler_kwargs, dict)
        check_consistency(loss, (LossInterface, _Loss), subclass=False)

        # assign variables
        self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
        self._loss = loss
        self._neural_net = self.models[0]

        # inverse problem handling
        if isinstance(self.problem, InverseProblem):
            self._params = self.problem.unknown_parameters
        else:
            self._params = None

    def forward(self, x):
        """
        Forward pass implementation for the PINN
        solver.

        :param torch.Tensor x: Input tensor.
        :return: PINN solution.
        :rtype: torch.Tensor
        """
        return self.neural_net(x)

    def configure_optimizers(self):
        """
        Optimizer configuration for the PINN
        solver.

        :return: The optimizers and the schedulers
        :rtype: tuple(list, list)
        """
        # if the problem is an InverseProblem, add the unknown parameters
        # to the parameters that the optimizer needs to optimize
        if isinstance(self.problem, InverseProblem):
            self.optimizers[0].add_param_group(
                {
                    "params": [
                        self._params[var]
                        for var in self.problem.unknown_variables
                    ]
                }
            )
        return self.optimizers, [self.scheduler]

    def _clamp_inverse_problem_params(self):
        for v in self._params:
            self._params[v].data.clamp_(
                self.problem.unknown_parameter_domain.range_[v][0],
                self.problem.unknown_parameter_domain.range_[v][1],
            )

    def _loss_data(self, input, output):
        return self.loss(self.forward(input), output)

    def _loss_phys(self, samples, equation):
        try:
            residual = equation.residual(samples, self.forward(samples))
        except (
            TypeError
        ):  # this occurs when the function has three inputs, i.e. inverse problem
            residual = equation.residual(
                samples, self.forward(samples), self._params
            )
        return self.loss(
            torch.zeros_like(residual, requires_grad=True), residual
        )

    def training_step(self, batch, batch_idx):
        """
        PINN solver training step.

        :param batch: The batch element in the dataloader.
        :type batch: tuple
        :param batch_idx: The batch index.
        :type batch_idx: int
        :return: The sum of the loss functions.
        :rtype: LabelTensor
        """

        dataloader = self.trainer.train_dataloader
        condition_losses = []

        condition_idx = batch["condition"]

        for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

            if sys.version_info >= (3, 8):
                condition_name = dataloader.condition_names[condition_id]
            else:
                condition_name = dataloader.loaders.condition_names[
                    condition_id
                ]
            condition = self.problem.conditions[condition_name]
            pts = batch["pts"]

            if len(batch) == 2:
                samples = pts[condition_idx == condition_id]
                loss = self._loss_phys(samples, condition.equation)
            elif len(batch) == 3:
                samples = pts[condition_idx == condition_id]
                ground_truth = batch["output"][condition_idx == condition_id]
                loss = self._loss_data(samples, ground_truth)
            else:
                raise ValueError("Batch size not supported")

            # TODO for users this is hard to remember when creating a new solver; fix in a smarter way
            loss = loss.as_subclass(torch.Tensor)

            # add condition losses and accumulate logging for each epoch
            condition_losses.append(loss * condition.data_weight)
            self.log(
                condition_name + "_loss",
                float(loss),
                prog_bar=True,
                logger=True,
                on_epoch=True,
                on_step=False,
            )

        # clamp unknown parameters of the InverseProblem to their domain ranges (if needed)
        if isinstance(self.problem, InverseProblem):
            self._clamp_inverse_problem_params()

        # TODO Fix the bug: tot_loss is a label tensor without labels;
        # we need to pass it as a torch tensor to make everything work
        total_loss = sum(condition_losses)
        self.log(
            "mean_loss",
            float(total_loss / len(condition_losses)),
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=False,
        )

        return total_loss

    @property
    def scheduler(self):
        """
        Scheduler for the PINN training.
        """
        return self._scheduler

    @property
    def neural_net(self):
        """
        Neural network for the PINN training.
        """
        return self._neural_net

    @property
    def loss(self):
        """
        Loss for the PINN training.
        """
        return self._loss
pina/solvers/pinns/__init__.py (new file)
@@ -0,0 +1,15 @@
__all__ = [
    "PINNInterface",
    "PINN",
    "GPINN",
    "CausalPINN",
    "CompetitivePINN",
    "SAPINN",
]

from .basepinn import PINNInterface
from .pinn import PINN
from .gpinn import GPINN
from .causalpinn import CausalPINN
from .competitive_pinn import CompetitivePINN
from .sapinn import SAPINN
pina/solvers/pinns/basepinn.py (new file)
@@ -0,0 +1,247 @@
""" Module for PINN """

import sys
from abc import ABCMeta, abstractmethod
import torch

from ...solvers.solver import SolverInterface
from pina.utils import check_consistency
from pina.loss import LossInterface
from pina.problem import InverseProblem
from torch.nn.modules.loss import _Loss

torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732


class PINNInterface(SolverInterface, metaclass=ABCMeta):
    """
    Base PINN solver class. This class implements the Solver Interface
    for Physics Informed Neural Network solvers.

    This class can be used to
    define PINNs with multiple ``optimizers``, and/or ``models``.
    By default it takes
    an :class:`~pina.problem.abstract_problem.AbstractProblem`, so it is up
    to the user to choose which problem the implemented solver inheriting from
    this class is suitable for.
    """

    def __init__(
        self,
        models,
        problem,
        optimizers,
        optimizers_kwargs,
        extra_features,
        loss,
    ):
        """
        :param models: Multiple torch neural network models instances.
        :type models: list(torch.nn.Module)
        :param problem: A problem definition instance.
        :type problem: AbstractProblem
        :param list(torch.optim.Optimizer) optimizer: A list of neural network
            optimizers to use.
        :param list(dict) optimizer_kwargs: A list of optimizer constructor
            keyword args.
        :param list(torch.nn.Module) extra_features: The additional input
            features to use as augmented input. If ``None`` no extra features
            are passed. If it is a list of :class:`torch.nn.Module`,
            the extra feature list is passed to all models. If it is a list
            of extra features' lists, each single list of extra features
            is passed to a model.
        :param torch.nn.Module loss: The loss function used as minimizer,
            default :class:`torch.nn.MSELoss`.
        """
        super().__init__(
            models=models,
            problem=problem,
            optimizers=optimizers,
            optimizers_kwargs=optimizers_kwargs,
            extra_features=extra_features,
        )

        # check consistency
        check_consistency(loss, (LossInterface, _Loss), subclass=False)

        # assign variables
        self._loss = loss

        # inverse problem handling
        if isinstance(self.problem, InverseProblem):
            self._params = self.problem.unknown_parameters
            self._clamp_params = self._clamp_inverse_problem_params
        else:
            self._params = None
            self._clamp_params = lambda: None

        # variable used internally to store residual losses at each epoch;
        # it saves the residual at each iteration (not weighted)
        self.__logged_res_losses = []

        # variable used internally in pina for logging. This variable points to
        # the current condition during the training step and returns the
        # condition name. Whenever :meth:`store_log` is called the logged
        # variable will be stored with name = self.__logged_metric
        self.__logged_metric = None

    def training_step(self, batch, _):
        """
        The Physics Informed Solver training step. This function takes care
        of the physics informed training step, and it must not be overridden
        unintentionally. It handles the batching mechanism, the workload
        division for the various conditions, the inverse problem clamping,
        and loggers.

        :param tuple batch: The batch element in the dataloader.
        :param int batch_idx: The batch index.
        :return: The sum of the loss functions.
        :rtype: LabelTensor
        """

        condition_losses = []
        condition_idx = batch["condition"]

        for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

            condition_name = self._dataloader.condition_names[condition_id]
            condition = self.problem.conditions[condition_name]
            pts = batch["pts"]
            # condition name is logged (if logs enabled)
            self.__logged_metric = condition_name

            if len(batch) == 2:
                samples = pts[condition_idx == condition_id]
                loss = self.loss_phys(samples, condition.equation)
            elif len(batch) == 3:
                samples = pts[condition_idx == condition_id]
                ground_truth = batch["output"][condition_idx == condition_id]
                loss = self.loss_data(samples, ground_truth)
            else:
                raise ValueError("Batch size not supported")

            # add condition losses for each epoch
            condition_losses.append(loss * condition.data_weight)

        # clamp unknown parameters in InverseProblem (if needed)
        self._clamp_params()

        # total loss (must be a torch.Tensor)
        total_loss = sum(condition_losses)
        return total_loss.as_subclass(torch.Tensor)

    def loss_data(self, input_tensor, output_tensor):
        """
        The data loss for the PINN solver. It computes the loss between
        the network output and the true solution. This function
        should not be overridden unintentionally.

        :param LabelTensor input_tensor: The input to the neural networks.
        :param LabelTensor output_tensor: The true solution to compare with
            the network solution.
        :return: The residual loss averaged on the input coordinates.
        :rtype: torch.Tensor
        """
        loss_value = self.loss(self.forward(input_tensor), output_tensor)
        self.store_log(loss_value=float(loss_value))
        return self.loss(self.forward(input_tensor), output_tensor)

    @abstractmethod
    def loss_phys(self, samples, equation):
        """
        Computes the physics loss for the physics informed solver based on given
        samples and equation. This method must be overridden by all inherited
        classes, and it is the core of defining a new physics informed solver.

        :param LabelTensor samples: The samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation
            representing the physics.
        :return: The physics loss calculated based on given
            samples and equation.
        :rtype: LabelTensor
        """
        pass

    def compute_residual(self, samples, equation):
        """
        Compute the residual for Physics Informed learning. This function
        returns the :obj:`~pina.equation.equation.Equation` specified in the
        :obj:`~pina.condition.Condition` evaluated at the ``samples`` points.

        :param LabelTensor samples: The samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation
            representing the physics.
        :return: The residual of the neural network solution.
        :rtype: LabelTensor
        """
        try:
            residual = equation.residual(samples, self.forward(samples))
        except (
            TypeError
        ):  # this occurs when the function has three inputs, i.e. inverse problem
            residual = equation.residual(
                samples, self.forward(samples), self._params
            )
        return residual

    def store_log(self, loss_value):
        """
        Stores the loss value in the logger. This function should be
        called for all conditions. It automatically handles the storing of
        condition names. It must be used
        anytime a specific variable has to be stored for a specific condition.
        A simple example is to use the variable to store the residual.

        :param torch.Tensor loss_value: The value of the loss.
        """
        self.log(
            self.__logged_metric + '_loss',
            loss_value,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=False,
        )
        self.__logged_res_losses.append(loss_value)

    def on_train_epoch_end(self):
        """
        At the end of each epoch we free the stored losses. This function
        should not be overridden unintentionally.
        """
        if self.__logged_res_losses:
            # storing mean loss
            self.__logged_metric = 'mean'
            self.store_log(
                sum(self.__logged_res_losses) / len(self.__logged_res_losses)
            )
            # free the logged losses
            self.__logged_res_losses = []
        return super().on_train_epoch_end()

    def _clamp_inverse_problem_params(self):
        """
        Clamps the parameters of the inverse problem
        solver to the specified ranges.
        """
        for v in self._params:
            self._params[v].data.clamp_(
                self.problem.unknown_parameter_domain.range_[v][0],
                self.problem.unknown_parameter_domain.range_[v][1],
            )

    @property
    def loss(self):
        """
        Loss used for training.
        """
        return self._loss

    @property
    def current_condition_name(self):
        """
        Returns the condition name. This function can be used inside the
        :meth:`loss_phys` to extract the condition at which the loss is
        computed.
        """
        return self.__logged_metric
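Since `loss_phys` is the only abstract method, defining a new physics-informed solver on top of `PINNInterface` mostly amounts to overriding it. A minimal hypothetical subclass (illustrative only; a usable solver would also implement `forward` and `configure_optimizers`, as `PINN` below does):

import torch
from pina.solvers.pinns.basepinn import PINNInterface

class PlainResidualPINN(PINNInterface):
    """Hypothetical solver: un-weighted mean-square residual loss."""

    def loss_phys(self, samples, equation):
        # compute_residual transparently handles forward and inverse problems
        residual = self.compute_residual(samples=samples, equation=equation)
        loss_value = self.loss(
            torch.zeros_like(residual, requires_grad=True), residual
        )
        # logs under the condition currently being processed
        self.store_log(loss_value=float(loss_value))
        return loss_value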
pina/solvers/pinns/causalpinn.py (new file)
@@ -0,0 +1,221 @@
""" Module for CausalPINN """

import torch

from torch.optim.lr_scheduler import ConstantLR

from .pinn import PINN
from pina.problem import TimeDependentProblem
from pina.utils import check_consistency


class CausalPINN(PINN):
    r"""
    Causal Physics Informed Neural Network (PINN) solver class.
    This class implements Causal Physics Informed Neural
    Network solvers, using a user specified ``model`` to solve a specific
    ``problem``. It can be used for solving both forward and inverse problems.

    The Causal Physics Informed Network aims to find
    the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
    of the differential problem:

    .. math::

        \begin{cases}
        \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
        \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
        \mathbf{x}\in\partial\Omega
        \end{cases}

    minimizing the loss function

    .. math::
        \mathcal{L}_{\rm{problem}} = \frac{1}{N_t}\sum_{i=1}^{N_t}
        \omega_{i}\mathcal{L}_r(t_i),

    where:

    .. math::
        \mathcal{L}_r(t) = \frac{1}{N}\sum_{i=1}^N
        \mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i, t)) +
        \frac{1}{N}\sum_{i=1}^N
        \mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i, t))

    and,

    .. math::
        \omega_i = \exp\left(-\epsilon \sum_{k=1}^{i-1}\mathcal{L}_r(t_k)\right).

    :math:`\epsilon` is a hyperparameter, default set to :math:`100`, while
    :math:`\mathcal{L}` is a specific loss function,
    default Mean Square Error:

    .. math::
        \mathcal{L}(v) = \| v \|^2_2.


    .. seealso::

        **Original reference**: Wang, Sifan, Shyam Sankaran, and Paris
        Perdikaris. "Respecting causality for training physics-informed
        neural networks." Computer Methods in Applied Mechanics
        and Engineering 421 (2024): 116813.
        DOI `10.1016 <https://doi.org/10.1016/j.cma.2024.116813>`_.

    .. note::
        This class can only work for problems inheriting
        from at least
        :class:`~pina.problem.timedep_problem.TimeDependentProblem` class.
    """

    def __init__(
        self,
        problem,
        model,
        extra_features=None,
        loss=torch.nn.MSELoss(),
        optimizer=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001},
        scheduler=ConstantLR,
        scheduler_kwargs={"factor": 1, "total_iters": 0},
        eps=100,
    ):
        """
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.nn.Module model: The neural network model to use.
        :param torch.nn.Module loss: The loss function used as minimizer,
            default :class:`torch.nn.MSELoss`.
        :param torch.nn.Module extra_features: The additional input
            features to use as augmented input.
        :param torch.optim.Optimizer optimizer: The neural network optimizer to
            use; default is :class:`torch.optim.Adam`.
        :param dict optimizer_kwargs: Optimizer constructor keyword args.
        :param torch.optim.LRScheduler scheduler: Learning
            rate scheduler.
        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
        :param int | float eps: The exponential decay parameter. Note that this
            value is kept fixed during the training, but can be changed by means
            of a callback, e.g. for annealing.
        """
        super().__init__(
            problem=problem,
            model=model,
            extra_features=extra_features,
            loss=loss,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            scheduler=scheduler,
            scheduler_kwargs=scheduler_kwargs,
        )

        # checking consistency
        check_consistency(eps, (int, float))
        self._eps = eps
        if not isinstance(self.problem, TimeDependentProblem):
            raise ValueError('Causal PINN works only for problems '
                             'inheriting from TimeDependentProblem.')

    def loss_phys(self, samples, equation):
        """
        Computes the physics loss for the Causal PINN solver based on given
        samples and equation.

        :param LabelTensor samples: The samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation
            representing the physics.
        :return: The physics loss calculated based on given
            samples and equation.
        :rtype: LabelTensor
        """
        # split sequentially ordered time tensors into chunks
        chunks, labels = self._split_tensor_into_chunks(samples)
        # compute residuals - these correspond to the ordered loss function
        # values for each time step. We apply `flatten` such that after
        # concatenating the residuals we obtain a tensor of shape #chunks
        time_loss = []
        for chunk in chunks:
            chunk.labels = labels
            # classical PINN loss
            residual = self.compute_residual(samples=chunk, equation=equation)
            loss_val = self.loss(
                torch.zeros_like(residual, requires_grad=True), residual
            )
            time_loss.append(loss_val)
        # store results
        self.store_log(loss_value=float(sum(time_loss) / len(time_loss)))
        # concatenate residuals
        time_loss = torch.stack(time_loss)
        # compute weights (without the gradient storing)
        with torch.no_grad():
            weights = self._compute_weights(time_loss)
        return (weights * time_loss).mean()

    @property
    def eps(self):
        """
        The exponential decay parameter.
        """
        return self._eps

    @eps.setter
    def eps(self, value):
        """
        Setter method for the eps parameter.

        :param float value: The exponential decay parameter.
        """
        check_consistency(value, float)
        self._eps = value

    def _sort_label_tensor(self, tensor):
        """
        Sorts the label tensor based on time variables.

        :param LabelTensor tensor: The label tensor to be sorted.
        :return: The sorted label tensor based on time variables.
        :rtype: LabelTensor
        """
        # labels input tensors
        labels = tensor.labels
        # extract time tensor
        time_tensor = tensor.extract(self.problem.temporal_domain.variables)
        # sort the time tensors (this is very bad for GPU)
        _, idx = torch.sort(time_tensor.tensor.flatten())
        tensor = tensor[idx]
        tensor.labels = labels
        return tensor

    def _split_tensor_into_chunks(self, tensor):
        """
        Splits the label tensor into chunks based on time.

        :param LabelTensor tensor: The label tensor to be split.
        :return: Tuple containing the chunks and the original labels.
        :rtype: Tuple[List[LabelTensor], List]
        """
        # labels input tensors
        labels = tensor.labels
        # sort input tensors by time
        tensor = self._sort_label_tensor(tensor)
        # extract time tensor
        time_tensor = tensor.extract(self.problem.temporal_domain.variables)
        # count unique tensors in time
        _, idx_split = time_tensor.unique(return_counts=True)
        # splitting
        chunks = torch.split(tensor, tuple(idx_split))
        return chunks, labels  # return chunks

    def _compute_weights(self, loss):
        """
        Computes the weights for the physics loss based on the cumulative loss.

        :param LabelTensor loss: The physics loss values.
        :return: The computed weights for the physics loss.
        :rtype: LabelTensor
        """
        # compute cumulative loss and multiply by epsilon
        cumulative_loss = self._eps * torch.cumsum(loss, dim=0)
        # return the exponential of the weighted negative cumulative sum
        return torch.exp(-cumulative_loss)
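The causal weighting is easy to check in isolation: the weight of each time step decays with the accumulated residual loss of the earlier steps, so later times only start to matter once earlier times are resolved. A standalone sketch mirroring `_compute_weights` (the loss values are made-up):

import torch

eps = 100.0                                   # the CausalPINN `eps` hyperparameter
time_loss = torch.tensor([0.05, 0.03, 0.02])  # per-time-step residual losses, time-ordered

# w = exp(-eps * cumulative residual loss), as in _compute_weights
weights = torch.exp(-eps * torch.cumsum(time_loss, dim=0))
causal_loss = (weights * time_loss).mean()    # weighted physics loss, as in loss_phys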
pina/solvers/pinns/competitive_pinn.py (new file)
@@ -0,0 +1,360 @@
""" Module for CompetitivePINN """

import torch
import copy

try:
    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
except ImportError:
    from torch.optim.lr_scheduler import (
        _LRScheduler as LRScheduler,
    )  # torch < 2.0

from torch.optim.lr_scheduler import ConstantLR

from .basepinn import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem


class CompetitivePINN(PINNInterface):
    r"""
    Competitive Physics Informed Neural Network (PINN) solver class.
    This class implements Competitive Physics Informed Neural
    Network solvers, using a user specified ``model`` to solve a specific
    ``problem``. It can be used for solving both forward and inverse problems.

    The Competitive Physics Informed Network aims to find
    the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
    of the differential problem:

    .. math::

        \begin{cases}
        \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
        \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
        \mathbf{x}\in\partial\Omega
        \end{cases}

    with a minimization (on ``model`` parameters) maximization (
    on ``discriminator`` parameters) of the loss function

    .. math::
        \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
        \mathcal{L}(D(\mathbf{x}_i)\mathcal{A}[\mathbf{u}](\mathbf{x}_i))+
        \frac{1}{N}\sum_{i=1}^N
        \mathcal{L}(D(\mathbf{x}_i)\mathcal{B}[\mathbf{u}](\mathbf{x}_i))

    where :math:`D` is the discriminator network, which tries to find the points
    where the network performs worst, and :math:`\mathcal{L}` is a specific loss
    function, default Mean Square Error:

    .. math::
        \mathcal{L}(v) = \| v \|^2_2.

    .. seealso::

        **Original reference**: Zeng, Qi, et al.
        "Competitive physics informed networks." International Conference on
        Learning Representations, ICLR 2022
        `OpenReview Preprint <https://openreview.net/forum?id=z9SIj-IM7tn>`_.

    .. warning::
        This solver does not currently support the possibility to pass
        ``extra_feature``.
    """

    def __init__(
        self,
        problem,
        model,
        discriminator=None,
        loss=torch.nn.MSELoss(),
        optimizer_model=torch.optim.Adam,
        optimizer_model_kwargs={"lr": 0.001},
        optimizer_discriminator=torch.optim.Adam,
        optimizer_discriminator_kwargs={"lr": 0.001},
        scheduler_model=ConstantLR,
        scheduler_model_kwargs={"factor": 1, "total_iters": 0},
        scheduler_discriminator=ConstantLR,
        scheduler_discriminator_kwargs={"factor": 1, "total_iters": 0},
    ):
        """
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.nn.Module model: The neural network model to use
            for the model.
        :param torch.nn.Module discriminator: The neural network model to use
            for the discriminator. If ``None``, the discriminator network will
            have the same architecture as the model network.
        :param torch.nn.Module loss: The loss function used as minimizer,
            default :class:`torch.nn.MSELoss`.
        :param torch.optim.Optimizer optimizer_model: The neural
            network optimizer to use for the model network,
            default is `torch.optim.Adam`.
        :param dict optimizer_model_kwargs: Optimizer constructor keyword
            args for the model.
        :param torch.optim.Optimizer optimizer_discriminator: The neural
            network optimizer to use for the discriminator network,
            default is `torch.optim.Adam`.
        :param dict optimizer_discriminator_kwargs: Optimizer constructor
            keyword args for the discriminator.
        :param torch.optim.LRScheduler scheduler_model: Learning
            rate scheduler for the model.
        :param dict scheduler_model_kwargs: LR scheduler constructor
            keyword args.
        :param torch.optim.LRScheduler scheduler_discriminator: Learning
            rate scheduler for the discriminator.
        """
        if discriminator is None:
            discriminator = copy.deepcopy(model)

        super().__init__(
            models=[model, discriminator],
            problem=problem,
            optimizers=[optimizer_model, optimizer_discriminator],
            optimizers_kwargs=[
                optimizer_model_kwargs,
                optimizer_discriminator_kwargs,
            ],
            extra_features=None,  # CompetitivePINN doesn't take extra features
            loss=loss
        )

        # set manual optimization, as for GANs
        self.automatic_optimization = False

        # check consistency
        check_consistency(scheduler_model, LRScheduler, subclass=True)
        check_consistency(scheduler_model_kwargs, dict)
        check_consistency(scheduler_discriminator, LRScheduler, subclass=True)
        check_consistency(scheduler_discriminator_kwargs, dict)

        # assign schedulers
        self._schedulers = [
            scheduler_model(
                self.optimizers[0], **scheduler_model_kwargs
            ),
            scheduler_discriminator(
                self.optimizers[1], **scheduler_discriminator_kwargs
            ),
        ]

        self._model = self.models[0]
        self._discriminator = self.models[1]

    def forward(self, x):
        r"""
        Forward pass implementation for the PINN solver. It returns the function
        evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
        :math:`\mathbf{x}`.

        :param LabelTensor x: Input tensor for the PINN solver. It expects
            a tensor :math:`N \times D`, where :math:`N` is the number of points
            in the mesh and :math:`D` the dimension of the problem.
        :return: PINN solution evaluated at the control points.
        :rtype: LabelTensor
        """
        return self.neural_net(x)

    def loss_phys(self, samples, equation):
        """
        Computes the physics loss for the Competitive PINN solver based on given
        samples and equation.

        :param LabelTensor samples: The samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation
            representing the physics.
        :return: The physics loss calculated based on given
            samples and equation.
        :rtype: LabelTensor
        """
        # train one step of the model
        with torch.no_grad():
            discriminator_bets = self.discriminator(samples)
        loss_val = self._train_model(samples, equation, discriminator_bets)
        self.store_log(loss_value=float(loss_val))
        # detach samples from the computational graph to erase it, and set
        # the gradient to true to create a new computational graph.
        # Alternatively, set `retain_graph=True`.
        samples = samples.detach()
        samples.requires_grad = True
        # train one step of the discriminator
        discriminator_bets = self.discriminator(samples)
        self._train_discriminator(samples, equation, discriminator_bets)
        return loss_val

    def loss_data(self, input_tensor, output_tensor):
        """
        The data loss for the PINN solver. It computes the loss between the
        network output and the true solution.

        :param LabelTensor input_tensor: The input to the neural networks.
        :param LabelTensor output_tensor: The true solution to compare with
            the network solution.
        :return: The computed data loss.
        :rtype: torch.Tensor
        """
        self.optimizer_model.zero_grad()
        loss_val = super().loss_data(
            input_tensor, output_tensor).as_subclass(torch.Tensor)
        loss_val.backward()
        self.optimizer_model.step()
        return loss_val

    def configure_optimizers(self):
        """
        Optimizer configuration for the Competitive PINN solver.

        :return: The optimizers and the schedulers
        :rtype: tuple(list, list)
        """
        # if the problem is an InverseProblem, add the unknown parameters
        # to the parameters that the optimizer needs to optimize
        if isinstance(self.problem, InverseProblem):
            self.optimizer_model.add_param_group(
                {
                    "params": [
                        self._params[var]
                        for var in self.problem.unknown_variables
                    ]
                }
            )
        return self.optimizers, self._schedulers

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """
        This method is called at the end of each training batch, and overrides
        the PyTorch Lightning implementation for logging the checkpoints.

        :param torch.Tensor outputs: The output from the model for the
            current batch.
        :param tuple batch: The current batch of data.
        :param int batch_idx: The index of the current batch.
        :return: Whatever is returned by the parent
            method ``on_train_batch_end``.
        :rtype: Any
        """
        # increase by one the counter of optimization to save loggers
        self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
        return super().on_train_batch_end(outputs, batch, batch_idx)

    def _train_discriminator(self, samples, equation, discriminator_bets):
        """
        Trains the discriminator network of the Competitive PINN.

        :param LabelTensor samples: Input samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation representing
            the physics.
        :param Tensor discriminator_bets: Predictions made by the discriminator
            network.
        """
        # manual optimization
        self.optimizer_discriminator.zero_grad()
        # compute residual; we detach because the weights of the generator
        # model are fixed
        residual = self.compute_residual(samples=samples,
                                         equation=equation).detach()
        # compute competitive residual; the minus is because we maximize
        competitive_residual = residual * discriminator_bets
        loss_val = - self.loss(
            torch.zeros_like(competitive_residual, requires_grad=True),
            competitive_residual
        ).as_subclass(torch.Tensor)
        # backprop
        self.manual_backward(loss_val)
        self.optimizer_discriminator.step()
        return

    def _train_model(self, samples, equation, discriminator_bets):
        """
        Trains the model network of the Competitive PINN.

        :param LabelTensor samples: Input samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation representing
            the physics.
        :param Tensor discriminator_bets: Predictions made by the discriminator
            network.
        :return: The computed data loss.
        :rtype: torch.Tensor
        """
        # manual optimization
        self.optimizer_model.zero_grad()
        # compute residual (detached for discriminator) and log
        residual = self.compute_residual(samples=samples, equation=equation)
        # store logging
        with torch.no_grad():
            loss_residual = self.loss(
                torch.zeros_like(residual),
                residual
            )
        # compute competitive residual; discriminator_bets are detached because
        # we optimize only the generator model
        competitive_residual = residual * discriminator_bets.detach()
        loss_val = self.loss(
            torch.zeros_like(competitive_residual, requires_grad=True),
            competitive_residual
        ).as_subclass(torch.Tensor)
        # backprop
        self.manual_backward(loss_val)
        self.optimizer_model.step()
        return loss_residual

    @property
    def neural_net(self):
        """
        Returns the neural network model.

        :return: The neural network model.
        :rtype: torch.nn.Module
        """
        return self._model

    @property
    def discriminator(self):
        """
        Returns the discriminator model (if applicable).

        :return: The discriminator model.
        :rtype: torch.nn.Module
        """
        return self._discriminator

    @property
    def optimizer_model(self):
        """
        Returns the optimizer associated with the neural network model.

        :return: The optimizer for the neural network model.
        :rtype: torch.optim.Optimizer
        """
        return self.optimizers[0]

    @property
    def optimizer_discriminator(self):
        """
        Returns the optimizer associated with the discriminator (if applicable).

        :return: The optimizer for the discriminator.
        :rtype: torch.optim.Optimizer
        """
        return self.optimizers[1]

    @property
    def scheduler_model(self):
        """
        Returns the scheduler associated with the neural network model.

        :return: The scheduler for the neural network model.
        :rtype: torch.optim.lr_scheduler._LRScheduler
        """
        return self._schedulers[0]

    @property
    def scheduler_discriminator(self):
        """
        Returns the scheduler associated with the discriminator (if applicable).

        :return: The scheduler for the discriminator.
        :rtype: torch.optim.lr_scheduler._LRScheduler
        """
        return self._schedulers[1]
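Stripped of the Lightning plumbing, the competitive update is two alternating manual steps on the bet-weighted residual: gradient descent on the model with the discriminator frozen, then gradient ascent on the discriminator with the model frozen. A plain-torch sketch of the idea (`residual_fn`, `model_opt`, `disc_opt`, and `disc` are assumed, illustrative names, not PINA API):

import torch

def competitive_step(model_opt, disc_opt, residual_fn, disc, samples):
    # model step: minimize ||D(x) * r(x)||^2 with the discriminator frozen
    model_opt.zero_grad()
    loss = (disc(samples).detach() * residual_fn(samples)).pow(2).mean()
    loss.backward()
    model_opt.step()

    # discriminator step: maximize the same quantity with the model frozen
    # (implemented as descending the negated loss)
    disc_opt.zero_grad()
    adv = -(disc(samples) * residual_fn(samples).detach()).pow(2).mean()
    adv.backward()
    disc_opt.step()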
pina/solvers/pinns/gpinn.py (new file)
@@ -0,0 +1,134 @@
""" Module for GPINN """

import torch

from torch.optim.lr_scheduler import ConstantLR

from .pinn import PINN
from pina.operators import grad
from pina.problem import SpatialProblem


class GPINN(PINN):
    r"""
    Gradient Physics Informed Neural Network (GPINN) solver class.
    This class implements Gradient Physics Informed Neural
    Network solvers, using a user specified ``model`` to solve a specific
    ``problem``. It can be used for solving both forward and inverse problems.

    The Gradient Physics Informed Network aims to find
    the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
    of the differential problem:

    .. math::

        \begin{cases}
        \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
        \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
        \mathbf{x}\in\partial\Omega
        \end{cases}

    minimizing the loss function

    .. math::
        \mathcal{L}_{\rm{problem}} =& \frac{1}{N}\sum_{i=1}^N
        \mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
        \frac{1}{N}\sum_{i=1}^N
        \mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)) + \\
        &\frac{1}{N}\sum_{i=1}^N
        \nabla_{\mathbf{x}}\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
        \frac{1}{N}\sum_{i=1}^N
        \nabla_{\mathbf{x}}\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i))

    where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error:

    .. math::
        \mathcal{L}(v) = \| v \|^2_2.

    .. seealso::

        **Original reference**: Yu, Jeremy, et al. "Gradient-enhanced
        physics-informed neural networks for forward and inverse
        PDE problems." Computer Methods in Applied Mechanics
        and Engineering 393 (2022): 114823.
        DOI: `10.1016 <https://doi.org/10.1016/j.cma.2022.114823>`_.

    .. note::
        This class can only work for problems inheriting
        from at least :class:`~pina.problem.spatial_problem.SpatialProblem`
        class.
    """

    def __init__(
        self,
        problem,
        model,
        extra_features=None,
        loss=torch.nn.MSELoss(),
        optimizer=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001},
        scheduler=ConstantLR,
        scheduler_kwargs={"factor": 1, "total_iters": 0},
    ):
        """
        :param AbstractProblem problem: The formulation of the problem. It must
            inherit from at least
            :class:`~pina.problem.spatial_problem.SpatialProblem` in order to
            compute the gradient of the loss.
        :param torch.nn.Module model: The neural network model to use.
        :param torch.nn.Module loss: The loss function used as minimizer,
            default :class:`torch.nn.MSELoss`.
        :param torch.nn.Module extra_features: The additional input
            features to use as augmented input.
        :param torch.optim.Optimizer optimizer: The neural network optimizer to
            use; default is :class:`torch.optim.Adam`.
        :param dict optimizer_kwargs: Optimizer constructor keyword args.
        :param torch.optim.LRScheduler scheduler: Learning
            rate scheduler.
        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
        """
        super().__init__(
            problem=problem,
            model=model,
            extra_features=extra_features,
            loss=loss,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            scheduler=scheduler,
            scheduler_kwargs=scheduler_kwargs,
        )
        if not isinstance(self.problem, SpatialProblem):
            raise ValueError('Gradient PINN computes the gradient of the '
                             'PINN loss with respect to the spatial '
                             'coordinates, thus the PINA problem must be '
                             'a SpatialProblem.')

    def loss_phys(self, samples, equation):
        """
        Computes the physics loss for the GPINN solver based on given
        samples and equation.

        :param LabelTensor samples: The samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation
            representing the physics.
        :return: The physics loss calculated based on given
            samples and equation.
        :rtype: LabelTensor
        """
        # classical PINN loss
        residual = self.compute_residual(samples=samples, equation=equation)
        loss_value = self.loss(
            torch.zeros_like(residual, requires_grad=True), residual
        )
        self.store_log(loss_value=float(loss_value))
        # gradient PINN loss
        loss_value = loss_value.reshape(-1, 1)
        loss_value.labels = ['__LOSS']
        loss_grad = grad(loss_value, samples, d=self.problem.spatial_variables)
        g_loss_phys = self.loss(
            torch.zeros_like(loss_grad, requires_grad=True), loss_grad
        )
        return loss_value + g_loss_phys
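Under the hood the extra GPINN term is nothing more than the spatial gradient of the pointwise residual loss, penalized in the same norm. A self-contained autograd sketch with a toy residual (illustrative only, not PINA code):

import torch

x = torch.rand(64, 2, requires_grad=True)          # spatial samples
u = torch.sin(x[:, :1]) * torch.cos(x[:, 1:])      # stand-in for the network output
residual = u - 0.5                                 # toy PDE residual r(x)

loss_res = residual.pow(2)                         # pointwise squared residual
grad_loss = torch.autograd.grad(                   # d(loss)/dx, shape (64, 2)
    loss_res.sum(), x, create_graph=True
)[0]
total = loss_res.mean() + grad_loss.pow(2).mean()  # PINN loss + gradient enhancement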
pina/solvers/pinns/pinn.py (new file)
@@ -0,0 +1,170 @@
""" Module for Physics Informed Neural Network. """

import torch

try:
    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
except ImportError:
    from torch.optim.lr_scheduler import (
        _LRScheduler as LRScheduler,
    )  # torch < 2.0

from torch.optim.lr_scheduler import ConstantLR

from .basepinn import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem


class PINN(PINNInterface):
    r"""
    Physics Informed Neural Network (PINN) solver class.
    This class implements Physics Informed Neural
    Network solvers, using a user specified ``model`` to solve a specific
    ``problem``. It can be used for solving both forward and inverse problems.

    The Physics Informed Network aims to find
    the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
    of the differential problem:

    .. math::

        \begin{cases}
        \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
        \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
        \mathbf{x}\in\partial\Omega
        \end{cases}

    minimizing the loss function

    .. math::
        \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
        \mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
        \frac{1}{N}\sum_{i=1}^N
        \mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i))

    where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error:

    .. math::
        \mathcal{L}(v) = \| v \|^2_2.

    .. seealso::

        **Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L.,
        Perdikaris, P., Wang, S., & Yang, L. (2021).
        Physics-informed machine learning. Nature Reviews Physics, 3, 422-440.
        DOI: `10.1038 <https://doi.org/10.1038/s42254-021-00314-5>`_.
    """
    def __init__(
        self,
        problem,
        model,
        extra_features=None,
        loss=torch.nn.MSELoss(),
        optimizer=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001},
        scheduler=ConstantLR,
        scheduler_kwargs={"factor": 1, "total_iters": 0},
    ):
        """
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.nn.Module model: The neural network model to use.
        :param torch.nn.Module loss: The loss function used as minimizer,
            default :class:`torch.nn.MSELoss`.
        :param torch.nn.Module extra_features: The additional input
            features to use as augmented input.
        :param torch.optim.Optimizer optimizer: The neural network optimizer to
            use; default is :class:`torch.optim.Adam`.
        :param dict optimizer_kwargs: Optimizer constructor keyword args.
        :param torch.optim.LRScheduler scheduler: Learning
            rate scheduler.
        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
        """
        super().__init__(
            models=[model],
            problem=problem,
            optimizers=[optimizer],
            optimizers_kwargs=[optimizer_kwargs],
            extra_features=extra_features,
            loss=loss
        )

        # check consistency
        check_consistency(scheduler, LRScheduler, subclass=True)
        check_consistency(scheduler_kwargs, dict)

        # assign variables
        self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
        self._neural_net = self.models[0]

    def forward(self, x):
        r"""
        Forward pass implementation for the PINN solver. It returns the function
        evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
        :math:`\mathbf{x}`.

        :param LabelTensor x: Input tensor for the PINN solver. It expects
            a tensor :math:`N \times D`, where :math:`N` is the number of points
            in the mesh and :math:`D` the dimension of the problem.
        :return: PINN solution evaluated at the control points.
        :rtype: LabelTensor
        """
        return self.neural_net(x)
    def loss_phys(self, samples, equation):
        """
        Computes the physics loss for the PINN solver based on given
        samples and equation.

        :param LabelTensor samples: The samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation
            representing the physics.
        :return: The physics loss calculated based on given
            samples and equation.
        :rtype: LabelTensor
        """
        residual = self.compute_residual(samples=samples, equation=equation)
        loss_value = self.loss(
            torch.zeros_like(residual, requires_grad=True), residual
        )
        self.store_log(loss_value=float(loss_value))
        return loss_value

    def configure_optimizers(self):
        """
        Optimizer configuration for the PINN solver.

        :return: The optimizers and the schedulers
        :rtype: tuple(list, list)
        """
        # if the problem is an InverseProblem, add the unknown parameters
        # to the parameters that the optimizer needs to optimize
        if isinstance(self.problem, InverseProblem):
            self.optimizers[0].add_param_group(
                {
                    "params": [
                        self._params[var]
                        for var in self.problem.unknown_variables
                    ]
                }
            )
        return self.optimizers, [self.scheduler]

    @property
    def scheduler(self):
        """
        Scheduler for the PINN training.
        """
        return self._scheduler

    @property
    def neural_net(self):
        """
        Neural network for the PINN training.
        """
        return self._neural_net
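
A minimal usage sketch for the solver above (assumptions as in the GPINN example: the Poisson problem and its discretisation follow the tests in this commit, and the hyperparameters are illustrative):

from pina.trainer import Trainer
from pina.model import FeedForward
from pina.solvers import PINN

problem = Poisson()  # hypothetical problem instance
problem.discretise_domain(10, 'grid',
                          locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
model = FeedForward(len(problem.input_variables),
                    len(problem.output_variables))
solver = PINN(problem=problem, model=model, optimizer_kwargs={'lr': 0.001})
Trainer(solver=solver, max_epochs=100, accelerator='cpu').train()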

494  pina/solvers/pinns/sapinn.py  Normal file
@@ -0,0 +1,494 @@
import torch
from copy import deepcopy

try:
    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
except ImportError:
    from torch.optim.lr_scheduler import (
        _LRScheduler as LRScheduler,
    )  # torch < 2.0

from .basepinn import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem

from torch.optim.lr_scheduler import ConstantLR


class Weights(torch.nn.Module):
    """
    This class implements the mask model for the
    self adaptive weights of the Self-Adaptive
    PINN solver.
    """

    def __init__(self, func):
        """
        :param torch.nn.Module func: the mask module of SAPINN
        """
        super().__init__()
        check_consistency(func, torch.nn.Module)
        self.sa_weights = torch.nn.Parameter(
            torch.Tensor()
        )
        self.func = func

    def forward(self):
        """
        Forward pass implementation for the mask module.
        It returns the function applied to the weights.

        :return: evaluation of the self adaptive weights through the mask.
        :rtype: torch.Tensor
        """
        return self.func(self.sa_weights)
class SAPINN(PINNInterface):
    r"""
    Self Adaptive Physics Informed Neural Network (SAPINN) solver class.
    This class implements Self-Adaptive Physics Informed Neural
    Network solvers, using a user specified ``model`` to solve a specific
    ``problem``. It can be used for solving both forward and inverse problems.

    The Self Adaptive Physics Informed Neural Network aims to find
    the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
    of the differential problem:

    .. math::

        \begin{cases}
        \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
        \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
        \mathbf{x}\in\partial\Omega
        \end{cases}

    integrating the pointwise loss evaluation through a mask :math:`m` and
    self adaptive weights that permit to focus the loss function on
    specific training samples.
    The loss function to solve the problem is

    .. math::

        \mathcal{L}_{\rm{problem}} = \frac{1}{N} \sum_{i=1}^{N_\Omega} m
        \left( \lambda_{\Omega}^{i} \right) \mathcal{L} \left( \mathcal{A}
        [\mathbf{u}](\mathbf{x}) \right) + \frac{1}{N}
        \sum_{i=1}^{N_{\partial\Omega}}
        m \left( \lambda_{\partial\Omega}^{i} \right) \mathcal{L}
        \left( \mathcal{B}[\mathbf{u}](\mathbf{x})
        \right),

    denoting the self adaptive weights as
    :math:`\lambda_{\Omega}^1, \dots, \lambda_{\Omega}^{N_\Omega}` and
    :math:`\lambda_{\partial \Omega}^1, \dots,
    \lambda_{\partial \Omega}^{N_{\partial \Omega}}`
    for :math:`\Omega` and :math:`\partial \Omega`, respectively.

    Self Adaptive Physics Informed Neural Network identifies the solution
    and appropriate self adaptive weights by solving the following problem

    .. math::

        \min_{w} \max_{\lambda_{\Omega}^k, \lambda_{\partial \Omega}^s}
        \mathcal{L} ,

    where :math:`w` denotes the network parameters, and
    :math:`\mathcal{L}` is a specific loss
    function, default Mean Square Error:

    .. math::
        \mathcal{L}(v) = \| v \|^2_2.

    .. seealso::
        **Original reference**: McClenny, Levi D., and Ulisses M. Braga-Neto.
        "Self-adaptive physics-informed neural networks."
        Journal of Computational Physics 474 (2023): 111722.
        DOI: `10.1016/j.jcp.2022.111722
        <https://doi.org/10.1016/j.jcp.2022.111722>`_.
    """
    def __init__(
        self,
        problem,
        model,
        weights_function=torch.nn.Sigmoid(),
        extra_features=None,
        loss=torch.nn.MSELoss(),
        optimizer_model=torch.optim.Adam,
        optimizer_model_kwargs={"lr" : 0.001},
        optimizer_weights=torch.optim.Adam,
        optimizer_weights_kwargs={"lr" : 0.001},
        scheduler_model=ConstantLR,
        scheduler_model_kwargs={"factor" : 1, "total_iters" : 0},
        scheduler_weights=ConstantLR,
        scheduler_weights_kwargs={"factor" : 1, "total_iters" : 0}
    ):
        """
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.nn.Module model: The neural network model to use
            for the model.
        :param torch.nn.Module weights_function: The neural network model
            related to the mask of SAPINN,
            default :obj:`~torch.nn.Sigmoid`.
        :param list(torch.nn.Module) extra_features: The additional input
            features to use as augmented input. If ``None`` no extra features
            are passed. If it is a list of :class:`torch.nn.Module`,
            the extra feature list is passed to all models. If it is a list
            of extra features' lists, each single list of extra features
            is passed to a model.
        :param torch.nn.Module loss: The loss function used as minimizer,
            default :class:`torch.nn.MSELoss`.
        :param torch.optim.Optimizer optimizer_model: The neural
            network optimizer to use for the model network,
            default is `torch.optim.Adam`.
        :param dict optimizer_model_kwargs: Optimizer constructor keyword
            args. for the model.
        :param torch.optim.Optimizer optimizer_weights: The neural
            network optimizer to use for the mask model,
            default is `torch.optim.Adam`.
        :param dict optimizer_weights_kwargs: Optimizer constructor
            keyword args. for the mask module.
        :param torch.optim.LRScheduler scheduler_model: Learning
            rate scheduler for the model.
        :param dict scheduler_model_kwargs: LR scheduler constructor
            keyword args.
        :param torch.optim.LRScheduler scheduler_weights: Learning
            rate scheduler for the mask model.
        :param dict scheduler_weights_kwargs: LR scheduler constructor
            keyword args.
        """

        # check consistency of weights_function
        check_consistency(weights_function, torch.nn.Module)

        # create models for weights
        weights_dict = {}
        for condition_name in problem.conditions:
            weights_dict[condition_name] = Weights(weights_function)
        weights_dict = torch.nn.ModuleDict(weights_dict)

        super().__init__(
            models=[model, weights_dict],
            problem=problem,
            optimizers=[optimizer_model, optimizer_weights],
            optimizers_kwargs=[
                optimizer_model_kwargs,
                optimizer_weights_kwargs
            ],
            extra_features=extra_features,
            loss=loss
        )

        # set automatic optimization
        self.automatic_optimization = False

        # check consistency
        check_consistency(scheduler_model, LRScheduler, subclass=True)
        check_consistency(scheduler_model_kwargs, dict)
        check_consistency(scheduler_weights, LRScheduler, subclass=True)
        check_consistency(scheduler_weights_kwargs, dict)

        # assign schedulers
        self._schedulers = [
            scheduler_model(
                self.optimizers[0], **scheduler_model_kwargs
            ),
            scheduler_weights(
                self.optimizers[1], **scheduler_weights_kwargs
            ),
        ]

        self._model = self.models[0]
        self._weights = self.models[1]

        self._vectorial_loss = deepcopy(loss)
        self._vectorial_loss.reduction = "none"
    def forward(self, x):
        r"""
        Forward pass implementation for the SAPINN solver.
        It returns the function
        evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
        :math:`\mathbf{x}`.

        :param LabelTensor x: Input tensor for the SAPINN solver. It expects
            a tensor :math:`N \times D`, where :math:`N` is the number of points
            in the mesh and :math:`D` the dimension of the problem.
        :return: PINN solution.
        :rtype: LabelTensor
        """
        return self.neural_net(x)

    def loss_phys(self, samples, equation):
        """
        Computes the physics loss for the SAPINN solver based on given
        samples and equation.

        :param LabelTensor samples: The samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation
            representing the physics.
        :return: The physics loss calculated based on given
            samples and equation.
        :rtype: torch.Tensor
        """
        # train weights
        self.optimizer_weights.zero_grad()
        weighted_loss, _ = self._loss_phys(samples, equation)
        loss_value = - weighted_loss.as_subclass(torch.Tensor)
        self.manual_backward(loss_value)
        self.optimizer_weights.step()

        # detach samples from the computational graph to erase it, and set
        # requires_grad to True to create a new computational graph.
        # Alternatively, set `retain_graph=True`.
        samples = samples.detach()
        samples.requires_grad = True

        # train model
        self.optimizer_model.zero_grad()
        weighted_loss, loss = self._loss_phys(samples, equation)
        loss_value = weighted_loss.as_subclass(torch.Tensor)
        self.manual_backward(loss_value)
        self.optimizer_model.step()

        # store loss without weights
        self.store_log(loss_value=float(loss))
        return loss_value

    def loss_data(self, input_tensor, output_tensor):
        """
        Computes the data loss for the SAPINN solver based on input and
        output. It computes the loss between the
        network output and the true solution.

        :param LabelTensor input_tensor: The input to the neural networks.
        :param LabelTensor output_tensor: The true solution to compare with the
            network solution.
        :return: The computed data loss.
        :rtype: torch.Tensor
        """
        # train weights
        self.optimizer_weights.zero_grad()
        weighted_loss, _ = self._loss_data(input_tensor, output_tensor)
        loss_value = - weighted_loss.as_subclass(torch.Tensor)
        self.manual_backward(loss_value)
        self.optimizer_weights.step()

        # detach samples from the computational graph to erase it, and set
        # requires_grad to True to create a new computational graph.
        # Alternatively, set `retain_graph=True`.
        input_tensor = input_tensor.detach()
        input_tensor.requires_grad = True

        # train model
        self.optimizer_model.zero_grad()
        weighted_loss, loss = self._loss_data(input_tensor, output_tensor)
        loss_value = weighted_loss.as_subclass(torch.Tensor)
        self.manual_backward(loss_value)
        self.optimizer_model.step()

        # store loss without weights
        self.store_log(loss_value=float(loss))
        return loss_value
    def configure_optimizers(self):
        """
        Optimizer configuration for the SAPINN solver.

        :return: The optimizers and the schedulers
        :rtype: tuple(list, list)
        """
        # if the problem is an InverseProblem, add the unknown parameters
        # to the parameters that the optimizer needs to optimize
        if isinstance(self.problem, InverseProblem):
            self.optimizers[0].add_param_group(
                {
                    "params": [
                        self._params[var]
                        for var in self.problem.unknown_variables
                    ]
                }
            )
        return self.optimizers, self._schedulers

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """
        This method is called at the end of each training batch, and overrides
        the PyTorch Lightning implementation for logging the checkpoints.

        :param torch.Tensor outputs: The output from the model for the
            current batch.
        :param tuple batch: The current batch of data.
        :param int batch_idx: The index of the current batch.
        :return: Whatever is returned by the parent
            method ``on_train_batch_end``.
        :rtype: Any
        """
        # increase by one the counter of optimization to save loggers
        self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
        return super().on_train_batch_end(outputs, batch, batch_idx)

    def on_train_start(self):
        """
        This method is called at the start of the training for setting
        the self adaptive weights as parameters of the mask model.

        :return: Whatever is returned by the parent
            method ``on_train_start``.
        :rtype: Any
        """
        device = torch.device(
            self.trainer._accelerator_connector._accelerator_flag
        )
        for condition_name, tensor in self.problem.input_pts.items():
            self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand(
                (tensor.shape[0], 1),
                device=device
            )
        return super().on_train_start()

    def on_load_checkpoint(self, checkpoint):
        """
        Overriding the PyTorch Lightning ``on_load_checkpoint`` to handle
        checkpoints for Self Adaptive Weights. This method should not be
        overridden unless intentionally.

        :param dict checkpoint: PyTorch Lightning checkpoint dict.
        """
        for condition_name, tensor in self.problem.input_pts.items():
            self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand(
                (tensor.shape[0], 1)
            )
        return super().on_load_checkpoint(checkpoint)
    def _loss_phys(self, samples, equation):
        """
        Elaboration of the physical loss for the SAPINN solver.

        :param LabelTensor samples: Input samples to evaluate the physics loss.
        :param EquationInterface equation: the governing equation representing
            the physics.

        :return: tuple with the weighted and the unweighted scalar loss
        :rtype: tuple(LabelTensor, LabelTensor)
        """
        residual = self.compute_residual(samples, equation)
        return self._compute_loss(residual)

    def _loss_data(self, input_tensor, output_tensor):
        """
        Elaboration of the loss related to data for the SAPINN solver.

        :param LabelTensor input_tensor: The input to the neural networks.
        :param LabelTensor output_tensor: The true solution to compare with the
            network solution.

        :return: tuple with the weighted and the unweighted scalar loss
        :rtype: tuple(LabelTensor, LabelTensor)
        """
        residual = self.forward(input_tensor) - output_tensor
        return self._compute_loss(residual)

    def _compute_loss(self, residual):
        """
        Elaboration of the pointwise loss through the mask model and the
        self adaptive weights.

        :param LabelTensor residual: the matrix of residuals that have to
            be weighted

        :return: tuple with the weighted and the unweighted loss
        :rtype: tuple(LabelTensor, LabelTensor)
        """
        weights = self.weights_dict.torchmodel[
            self.current_condition_name].forward()
        loss_value = self._vectorial_loss(torch.zeros_like(
            residual, requires_grad=True), residual)
        return (
            self._vect_to_scalar(weights * loss_value),
            self._vect_to_scalar(loss_value)
        )

    def _vect_to_scalar(self, loss_value):
        """
        Reduction of the pointwise loss matrix to a scalar, following the
        reduction of the user-specified loss.

        :param LabelTensor loss_value: the matrix of pointwise losses

        :return: the scalar loss
        :rtype: LabelTensor
        """
        if self.loss.reduction == "mean":
            ret = torch.mean(loss_value)
        elif self.loss.reduction == "sum":
            ret = torch.sum(loss_value)
        else:
            raise RuntimeError(f"Invalid reduction, got {self.loss.reduction} "
                               "but expected mean or sum.")
        return ret
    @property
    def neural_net(self):
        """
        Returns the neural network model.

        :return: The neural network model.
        :rtype: torch.nn.Module
        """
        return self.models[0]

    @property
    def weights_dict(self):
        """
        Return the mask models associated with the application of
        the mask to the self adaptive weights, one for each loss that
        composes the global loss of the problem.

        :return: The ModuleDict for the mask models.
        :rtype: torch.nn.ModuleDict
        """
        return self.models[1]

    @property
    def scheduler_model(self):
        """
        Returns the scheduler associated with the neural network model.

        :return: The scheduler for the neural network model.
        :rtype: torch.optim.lr_scheduler._LRScheduler
        """
        return self._schedulers[0]

    @property
    def scheduler_weights(self):
        """
        Returns the scheduler associated with the mask model (if applicable).

        :return: The scheduler for the mask model.
        :rtype: torch.optim.lr_scheduler._LRScheduler
        """
        return self._schedulers[1]

    @property
    def optimizer_model(self):
        """
        Returns the optimizer associated with the neural network model.

        :return: The optimizer for the neural network model.
        :rtype: torch.optim.Optimizer
        """
        return self.optimizers[0]

    @property
    def optimizer_weights(self):
        """
        Returns the optimizer associated with the mask model (if applicable).

        :return: The optimizer for the mask model.
        :rtype: torch.optim.Optimizer
        """
        return self.optimizers[1]
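
A usage sketch for SAPINN (same assumptions as above). Note that the solver runs with automatic_optimization disabled and keeps two optimizers, one for the network and one for the self adaptive weights:

import torch
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.solvers import SAPINN

problem = Poisson()  # hypothetical problem instance
problem.discretise_domain(10, 'grid',
                          locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
model = FeedForward(len(problem.input_variables),
                    len(problem.output_variables))
solver = SAPINN(problem=problem, model=model,
                weights_function=torch.nn.Sigmoid(),
                optimizer_model_kwargs={'lr': 0.001},
                optimizer_weights_kwargs={'lr': 0.001})
Trainer(solver=solver, max_epochs=10, accelerator='cpu').train()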

190  pina/solvers/rom.py  Normal file
@@ -0,0 +1,190 @@
""" Module for ReducedOrderModelSolver """

import torch

from pina.solvers import SupervisedSolver


class ReducedOrderModelSolver(SupervisedSolver):
    r"""
    ReducedOrderModelSolver solver class. This class implements a
    Reduced Order Model solver, using user specified ``reduction_network`` and
    ``interpolation_network`` to solve a specific ``problem``.

    The Reduced Order Model approach aims to find
    the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
    of the differential problem:

    .. math::

        \begin{cases}
        \mathcal{A}[\mathbf{u}(\mu)](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
        \mathcal{B}[\mathbf{u}(\mu)](\mathbf{x})=0\quad,
        \mathbf{x}\in\partial\Omega
        \end{cases}

    This is done by using two neural networks. The ``reduction_network``, which
    contains an encoder :math:`\mathcal{E}_{\rm{net}}` and a decoder
    :math:`\mathcal{D}_{\rm{net}}`; and an ``interpolation_network``
    :math:`\mathcal{I}_{\rm{net}}`. The input is assumed to be discretised in
    the spatial dimensions.

    The following loss function is minimized during training

    .. math::
        \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
        \mathcal{L}(\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)] -
        \mathcal{I}_{\rm{net}}[\mu_i]) +
        \mathcal{L}(
        \mathcal{D}_{\rm{net}}[\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)]] -
        \mathbf{u}(\mu_i))

    where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error:

    .. math::
        \mathcal{L}(v) = \| v \|^2_2.

    .. seealso::

        **Original reference**: Hesthaven, Jan S., and Stefano Ubbiali.
        "Non-intrusive reduced order modeling of nonlinear problems
        using neural networks." Journal of Computational
        Physics 363 (2018): 55-78.
        DOI `10.1016/j.jcp.2018.02.037
        <https://doi.org/10.1016/j.jcp.2018.02.037>`_.

    .. note::
        The specified ``reduction_network`` must contain two methods,
        namely ``encode`` for input encoding and ``decode`` for decoding the
        former result. The ``interpolation_network`` network ``forward`` output
        represents the interpolation of the latent space obtained with
        ``reduction_network.encode``.

    .. note::
        This solver uses the end-to-end training strategy, i.e. the
        ``reduction_network`` and ``interpolation_network`` are trained
        simultaneously. For reference on this training strategy look at:
        Pichi, Federico, Beatriz Moya, and Jan S. Hesthaven.
        "A graph convolutional autoencoder approach to model order reduction
        for parametrized PDEs." Journal of
        Computational Physics 501 (2024): 112762.
        DOI
        `10.1016/j.jcp.2024.112762 <https://doi.org/10.1016/
        j.jcp.2024.112762>`_.

    .. warning::
        This solver works only for data-driven models. Hence in the ``problem``
        definition the conditions must only contain ``input_points``
        (e.g. coefficient parameters, time parameters) and ``output_points``.

    .. warning::
        This solver does not currently support the possibility to pass
        ``extra_feature``.
    """
    def __init__(
        self,
        problem,
        reduction_network,
        interpolation_network,
        loss=torch.nn.MSELoss(),
        optimizer=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001},
        scheduler=torch.optim.lr_scheduler.ConstantLR,
        scheduler_kwargs={"factor": 1, "total_iters": 0},
    ):
        """
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.nn.Module reduction_network: The reduction network used
            for reducing the input space. It must contain two methods,
            namely ``encode`` for input encoding and ``decode`` for decoding the
            former result.
        :param torch.nn.Module interpolation_network: The interpolation network
            for interpolating the control parameters to the latent space
            obtained by the ``reduction_network`` encoding.
        :param torch.nn.Module loss: The loss function used as minimizer,
            default :class:`torch.nn.MSELoss`.
        :param torch.optim.Optimizer optimizer: The neural network optimizer to
            use; default is :class:`torch.optim.Adam`.
        :param dict optimizer_kwargs: Optimizer constructor keyword args.
        :param torch.optim.LRScheduler scheduler: Learning
            rate scheduler.
        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
        """
        model = torch.nn.ModuleDict({
            'reduction_network' : reduction_network,
            'interpolation_network' : interpolation_network})

        super().__init__(
            model=model,
            problem=problem,
            loss=loss,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            scheduler=scheduler,
            scheduler_kwargs=scheduler_kwargs
        )

        # assert reduction object contains encode/decode
        if not hasattr(self.neural_net['reduction_network'], 'encode'):
            raise SyntaxError('reduction_network must have encode method. '
                              'The encode method should return a lower '
                              'dimensional representation of the input.')
        if not hasattr(self.neural_net['reduction_network'], 'decode'):
            raise SyntaxError('reduction_network must have decode method. '
                              'The decode method should return a high '
                              'dimensional representation of the encoding.')
    def forward(self, x):
        """
        Forward pass implementation for the solver. It finds the encoder
        representation by calling ``interpolation_network.forward`` on the
        input, and maps this representation to the output space by calling
        ``reduction_network.decode``.

        :param torch.Tensor x: Input tensor.
        :return: Solver solution.
        :rtype: torch.Tensor
        """
        reduction_network = self.neural_net['reduction_network']
        interpolation_network = self.neural_net['interpolation_network']
        return reduction_network.decode(interpolation_network(x))

    def loss_data(self, input_pts, output_pts):
        """
        The data loss for the ReducedOrderModelSolver solver.
        It computes the loss between
        the network output against the true solution. This function
        should not be overridden unless intentionally.

        :param LabelTensor input_pts: The input to the neural networks.
        :param LabelTensor output_pts: The true solution to compare with the
            network solution.
        :return: The residual loss averaged on the input coordinates
        :rtype: torch.Tensor
        """
        # extract networks
        reduction_network = self.neural_net['reduction_network']
        interpolation_network = self.neural_net['interpolation_network']
        # encoded representations loss
        encode_repr_inter_net = interpolation_network(input_pts)
        encode_repr_reduction_network = reduction_network.encode(output_pts)
        loss_encode = self.loss(encode_repr_inter_net,
                                encode_repr_reduction_network)
        # reconstruction loss
        loss_reconstruction = self.loss(
            reduction_network.decode(encode_repr_reduction_network),
            output_pts)

        return loss_encode + loss_reconstruction

    @property
    def neural_net(self):
        """
        Neural network for training. It returns a :obj:`~torch.nn.ModuleDict`
        containing the ``reduction_network`` and ``interpolation_network``.
        """
        return self._neural_net.torchmodel
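
Since ``reduction_network`` must expose ``encode`` and ``decode``, a minimal sketch of a compatible pair of networks could look as follows; the AE class, its dimensions, and rom_problem are illustrative assumptions, not part of this commit:

import torch
from pina.model import FeedForward
from pina.solvers import ReducedOrderModelSolver

class AE(torch.nn.Module):
    # toy reduction network exposing the required encode/decode methods
    def __init__(self, dim_full, dim_latent):
        super().__init__()
        self._encoder = FeedForward(dim_full, dim_latent)
        self._decoder = FeedForward(dim_latent, dim_full)

    def encode(self, x):
        return self._encoder(x)

    def decode(self, x):
        return self._decoder(x)

reduction = AE(dim_full=100, dim_latent=5)
interpolation = FeedForward(2, 5)  # 2 input parameters -> 5 latent coordinates
solver = ReducedOrderModelSolver(problem=rom_problem,  # hypothetical data-driven problem
                                 reduction_network=reduction,
                                 interpolation_network=interpolation)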
@@ -6,6 +6,7 @@ import pytorch_lightning
from ..utils import check_consistency
from ..problem import AbstractProblem
import torch
import sys


class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):

@@ -142,6 +143,20 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
        The problem formulation."""
        return self._pina_problem

    def on_train_start(self):
        """
        On training start, this function is called to do global checks for
        the different solvers.
        """

        # 1. Check the version for the dataloader
        dataloader = self.trainer.train_dataloader
        if sys.version_info < (3, 8):
            dataloader = dataloader.loaders
        self._dataloader = dataloader

        return super().on_train_start()

    # @model.setter
    # def model(self, new_model):
    #     """
@@ -1,7 +1,6 @@
""" Module for SupervisedSolver """

import torch

try:
    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0

@@ -20,9 +19,32 @@ from torch.nn.modules.loss import _Loss


class SupervisedSolver(SolverInterface):
    r"""
    SupervisedSolver solver class. This class implements a SupervisedSolver,
    using a user specified ``model`` to solve a specific ``problem``.

    The Supervised Solver class aims to find
    a map between the input :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m`
    and the output :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. The input
    can be discretised in space (as in
    :obj:`~pina.solvers.rom.ReducedOrderModelSolver`),
    or not (e.g. when training Neural Operators).

    Given a model :math:`\mathcal{M}`, the following loss function is
    minimized during training:

    .. math::
        \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
        \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i))

    where :math:`\mathcal{L}` is a specific loss function,
    default Mean Square Error:

    .. math::
        \mathcal{L}(v) = \| v \|^2_2.

    In this context :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` mean that
    we are seeking to approximate multiple (discretised) functions given
    multiple (discretised) input functions.
    """

    def __init__(

@@ -97,17 +119,11 @@ class SupervisedSolver(SolverInterface):
        :rtype: LabelTensor
        """
        condition_idx = batch["condition"]

        for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

            condition_name = self._dataloader.condition_names[condition_id]
            condition = self.problem.conditions[condition_name]
            pts = batch["pts"]
            out = batch["output"]

@@ -118,14 +134,14 @@ class SupervisedSolver(SolverInterface):
            # for data driven mode
            if not hasattr(condition, "output_points"):
                raise NotImplementedError(
                    f"{type(self).__name__} works only in data-driven mode."
                )

            output_pts = out[condition_idx == condition_id]
            input_pts = pts[condition_idx == condition_id]

            loss = (
                self.loss_data(input_pts=input_pts, output_pts=output_pts)
                * condition.data_weight
            )
            loss = loss.as_subclass(torch.Tensor)

@@ -133,6 +149,20 @@ class SupervisedSolver(SolverInterface):
        self.log("mean_loss", float(loss), prog_bar=True, logger=True)
        return loss

    def loss_data(self, input_pts, output_pts):
        """
        The data loss for the Supervised solver. It computes the loss between
        the network output against the true solution. This function
        should not be overridden unless intentionally.

        :param LabelTensor input_pts: The input to the neural networks.
        :param LabelTensor output_pts: The true solution to compare with the
            network solution.
        :return: The residual loss averaged on the input coordinates
        :rtype: torch.Tensor
        """
        return self.loss(self.forward(input_pts), output_pts)

    @property
    def scheduler(self):
        """
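
Since all data-driven losses are now routed through ``loss_data``, a subclass can change the data metric without touching the batching logic; a hedged sketch (the RelativeL2Solver name is hypothetical):

import torch
from pina.solvers import SupervisedSolver

class RelativeL2Solver(SupervisedSolver):
    # hypothetical subclass: relative L2 data loss instead of self.loss
    def loss_data(self, input_pts, output_pts):
        pred = self.forward(input_pts)
        return torch.norm(pred - output_pts) / torch.norm(output_pts)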
@@ -1,5 +1,6 @@
""" Trainer module. """

import torch
import pytorch_lightning
from .utils import check_consistency
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset

@@ -63,6 +64,12 @@ class Trainer(pytorch_lightning.Trainer):
        self._loader = SamplePointLoader(
            dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
        )
        pb = self._model.problem
        if hasattr(pb, "unknown_parameters"):
            for key in pb.unknown_parameters:
                pb.unknown_parameters[key] = torch.nn.Parameter(
                    pb.unknown_parameters[key].data.to(device))

    def train(self, **kwargs):
        """
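
The Trainer change above completes the GPU fix for inverse problems: unknown parameters are created on CPU at problem definition, so they are re-wrapped as torch.nn.Parameter on the training device before optimization starts. A hedged sketch of the effect (InversePoisson, PINN, and model as in the examples and tests elsewhere in this commit):

# with the fix, inverse-problem training on an accelerator works end to end:
# Trainer moves problem.unknown_parameters (e.g. 'mu1', 'mu2') to the same
# device as the model, so configure_optimizers can register them safely
problem = InversePoisson()
problem.discretise_domain(100, 'random', locations=['D'])
solver = PINN(problem=problem, model=model)
Trainer(solver=solver, max_epochs=1, accelerator='gpu').train()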

266  tests/test_solvers/test_causalpinn.py  Normal file
@@ -0,0 +1,266 @@
import torch
import pytest

from pina.problem import TimeDependentProblem, InverseProblem, SpatialProblem
from pina.operators import grad
from pina.geometry import CartesianDomain
from pina import Condition, LabelTensor
from pina.solvers import CausalPINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
from pina.loss import LpLoss


class FooProblem(SpatialProblem):
    '''
    Foo problem formulation.
    '''
    output_variables = ['u']
    conditions = {}
    spatial_domain = None


class InverseDiffusionReactionSystem(TimeDependentProblem, SpatialProblem, InverseProblem):

    def diffusionreaction(input_, output_, params_):
        x = input_.extract('x')
        t = input_.extract('t')
        u_t = grad(output_, input_, d='t')
        u_x = grad(output_, input_, d='x')
        u_xx = grad(u_x, input_, d='x')
        r = torch.exp(-t) * (1.5 * torch.sin(2*x) + (8/3)*torch.sin(3*x) +
                             (15/4)*torch.sin(4*x) + (63/8)*torch.sin(8*x))
        return u_t - params_['mu']*u_xx - r

    def _solution(self, pts):
        t = pts.extract('t')
        x = pts.extract('x')
        return torch.exp(-t) * (torch.sin(x) + (1/2)*torch.sin(2*x) +
                                (1/3)*torch.sin(3*x) + (1/4)*torch.sin(4*x) +
                                (1/8)*torch.sin(8*x))

    # assign output/ spatial and temporal variables
    output_variables = ['u']
    spatial_domain = CartesianDomain({'x': [-torch.pi, torch.pi]})
    temporal_domain = CartesianDomain({'t': [0, 1]})
    unknown_parameter_domain = CartesianDomain({'mu': [-1, 1]})

    # problem condition statement
    conditions = {
        'D': Condition(location=CartesianDomain({'x': [-torch.pi, torch.pi],
                                                 't': [0, 1]}),
                       equation=Equation(diffusionreaction)),
        'data': Condition(input_points=LabelTensor(torch.tensor([[0., 0.]]), ['x', 't']),
                          output_points=LabelTensor(torch.tensor([[0.]]), ['u'])),
    }


class DiffusionReactionSystem(TimeDependentProblem, SpatialProblem):

    def diffusionreaction(input_, output_):
        x = input_.extract('x')
        t = input_.extract('t')
        u_t = grad(output_, input_, d='t')
        u_x = grad(output_, input_, d='x')
        u_xx = grad(u_x, input_, d='x')
        r = torch.exp(-t) * (1.5 * torch.sin(2*x) + (8/3)*torch.sin(3*x) +
                             (15/4)*torch.sin(4*x) + (63/8)*torch.sin(8*x))
        return u_t - u_xx - r

    def _solution(self, pts):
        t = pts.extract('t')
        x = pts.extract('x')
        return torch.exp(-t) * (torch.sin(x) + (1/2)*torch.sin(2*x) +
                                (1/3)*torch.sin(3*x) + (1/4)*torch.sin(4*x) +
                                (1/8)*torch.sin(8*x))

    # assign output/ spatial and temporal variables
    output_variables = ['u']
    spatial_domain = CartesianDomain({'x': [-torch.pi, torch.pi]})
    temporal_domain = CartesianDomain({'t': [0, 1]})

    # problem condition statement
    conditions = {
        'D': Condition(location=CartesianDomain({'x': [-torch.pi, torch.pi],
                                                 't': [0, 1]}),
                       equation=Equation(diffusionreaction)),
    }


class myFeature(torch.nn.Module):
    """
    Feature: sin(x)
    """

    def __init__(self):
        super(myFeature, self).__init__()

    def forward(self, x):
        t = (torch.sin(x.extract(['x']) * torch.pi))
        return LabelTensor(t, ['sin(x)'])


# make the problem
problem = DiffusionReactionSystem()
model = FeedForward(len(problem.input_variables),
                    len(problem.output_variables))
model_extra_feats = FeedForward(
    len(problem.input_variables) + 1,
    len(problem.output_variables))
extra_feats = [myFeature()]


def test_constructor():
    CausalPINN(problem=problem, model=model, extra_features=None)

    with pytest.raises(ValueError):
        CausalPINN(FooProblem(), model=model, extra_features=None)


def test_constructor_extra_feats():
    model_extra_feats = FeedForward(
        len(problem.input_variables) + 1,
        len(problem.output_variables))
    CausalPINN(problem=problem,
               model=model_extra_feats,
               extra_features=extra_feats)


def test_train_cpu():
    problem = DiffusionReactionSystem()
    boundaries = ['D']
    n = 10
    problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = CausalPINN(problem=problem,
                      model=model, extra_features=None, loss=LpLoss())
    trainer = Trainer(solver=pinn, max_epochs=1,
                      accelerator='cpu', batch_size=20)
    trainer.train()


def test_train_restore():
    tmpdir = "tests/tmp_restore"
    problem = DiffusionReactionSystem()
    boundaries = ['D']
    n = 10
    problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = CausalPINN(problem=problem,
                      model=model,
                      extra_features=None,
                      loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=5,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
    t = ntrainer.train(
        ckpt_path=f'{tmpdir}/lightning_logs/version_0/'
        'checkpoints/epoch=4-step=5.ckpt')
    import shutil
    shutil.rmtree(tmpdir)


def test_train_load():
    tmpdir = "tests/tmp_load"
    problem = DiffusionReactionSystem()
    boundaries = ['D']
    n = 10
    problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = CausalPINN(problem=problem,
                      model=model,
                      extra_features=None,
                      loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=15,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    new_pinn = CausalPINN.load_from_checkpoint(
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
        problem=problem, model=model)
    test_pts = CartesianDomain({'x': [0, 1], 't': [0, 1]}).sample(10)
    assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
    assert new_pinn.forward(test_pts).extract(
        ['u']).shape == pinn.forward(test_pts).extract(['u']).shape
    torch.testing.assert_close(
        new_pinn.forward(test_pts).extract(['u']),
        pinn.forward(test_pts).extract(['u']))
    import shutil
    shutil.rmtree(tmpdir)


def test_train_inverse_problem_cpu():
    problem = InverseDiffusionReactionSystem()
    boundaries = ['D']
    n = 100
    problem.discretise_domain(n, 'random', locations=boundaries)
    pinn = CausalPINN(problem=problem,
                      model=model, extra_features=None, loss=LpLoss())
    trainer = Trainer(solver=pinn, max_epochs=1,
                      accelerator='cpu', batch_size=20)
    trainer.train()


# # TODO does not currently work
# def test_train_inverse_problem_restore():
#     tmpdir = "tests/tmp_restore_inv"
#     problem = InverseDiffusionReactionSystem()
#     boundaries = ['D']
#     n = 100
#     problem.discretise_domain(n, 'random', locations=boundaries)
#     pinn = CausalPINN(problem=problem,
#                       model=model,
#                       extra_features=None,
#                       loss=LpLoss())
#     trainer = Trainer(solver=pinn,
#                       max_epochs=5,
#                       accelerator='cpu',
#                       default_root_dir=tmpdir)
#     trainer.train()
#     ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
#     t = ntrainer.train(
#         ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
#     import shutil
#     shutil.rmtree(tmpdir)


def test_train_inverse_problem_load():
    tmpdir = "tests/tmp_load_inv"
    problem = InverseDiffusionReactionSystem()
    boundaries = ['D']
    n = 100
    problem.discretise_domain(n, 'random', locations=boundaries)
    pinn = CausalPINN(problem=problem,
                      model=model,
                      extra_features=None,
                      loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=15,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    new_pinn = CausalPINN.load_from_checkpoint(
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
        problem=problem, model=model)
    test_pts = CartesianDomain({'x': [0, 1], 't': [0, 1]}).sample(10)
    assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
    assert new_pinn.forward(test_pts).extract(
        ['u']).shape == pinn.forward(test_pts).extract(['u']).shape
    torch.testing.assert_close(
        new_pinn.forward(test_pts).extract(['u']),
        pinn.forward(test_pts).extract(['u']))
    import shutil
    shutil.rmtree(tmpdir)


def test_train_extra_feats_cpu():
    problem = DiffusionReactionSystem()
    boundaries = ['D']
    n = 10
    problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = CausalPINN(problem=problem,
                      model=model_extra_feats,
                      extra_features=extra_feats)
    trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
    trainer.train()

418  tests/test_solvers/test_competitive_pinn.py  Normal file
@@ -0,0 +1,418 @@
import torch
import pytest

from pina.problem import SpatialProblem, InverseProblem
from pina.operators import laplacian
from pina.geometry import CartesianDomain
from pina import Condition, LabelTensor
from pina.solvers import CompetitivePINN as PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
from pina.loss import LpLoss


def laplace_equation(input_, output_):
    force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
                  torch.sin(input_.extract(['y']) * torch.pi))
    delta_u = laplacian(output_.extract(['u']), input_)
    return delta_u - force_term


my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
out2_ = LabelTensor(torch.rand(60, 1), ['u'])


class InversePoisson(SpatialProblem, InverseProblem):
    '''
    Problem definition for the Poisson equation.
    '''
    output_variables = ['u']
    x_min = -2
    x_max = 2
    y_min = -2
    y_max = 2
    data_input = LabelTensor(torch.rand(10, 2), ['x', 'y'])
    data_output = LabelTensor(torch.rand(10, 1), ['u'])
    spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]})
    # define the ranges for the parameters
    unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]})

    def laplace_equation(input_, output_, params_):
        '''
        Laplace equation with a force term.
        '''
        force_term = torch.exp(
            - 2*(input_.extract(['x']) - params_['mu1'])**2
            - 2*(input_.extract(['y']) - params_['mu2'])**2)
        delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y'])

        return delta_u - force_term

    # define the conditions for the loss (boundary conditions, equation, data)
    conditions = {
        'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max],
                                                      'y': y_max}),
                            equation=FixedValue(0.0, components=['u'])),
        'gamma2': Condition(location=CartesianDomain(
                            {'x': [x_min, x_max], 'y': y_min}),
                            equation=FixedValue(0.0, components=['u'])),
        'gamma3': Condition(location=CartesianDomain(
                            {'x': x_max, 'y': [y_min, y_max]}),
                            equation=FixedValue(0.0, components=['u'])),
        'gamma4': Condition(location=CartesianDomain(
                            {'x': x_min, 'y': [y_min, y_max]}),
                            equation=FixedValue(0.0, components=['u'])),
        'D': Condition(location=CartesianDomain(
                       {'x': [x_min, x_max], 'y': [y_min, y_max]}),
                       equation=Equation(laplace_equation)),
        'data': Condition(input_points=data_input.extract(['x', 'y']),
                          output_points=data_output)
    }


class Poisson(SpatialProblem):
    output_variables = ['u']
    spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})

    conditions = {
        'gamma1': Condition(
            location=CartesianDomain({'x': [0, 1], 'y': 1}),
            equation=FixedValue(0.0)),
        'gamma2': Condition(
            location=CartesianDomain({'x': [0, 1], 'y': 0}),
            equation=FixedValue(0.0)),
        'gamma3': Condition(
            location=CartesianDomain({'x': 1, 'y': [0, 1]}),
            equation=FixedValue(0.0)),
        'gamma4': Condition(
            location=CartesianDomain({'x': 0, 'y': [0, 1]}),
            equation=FixedValue(0.0)),
        'D': Condition(
            input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
            equation=my_laplace),
        'data': Condition(
            input_points=in_,
            output_points=out_),
        'data2': Condition(
            input_points=in2_,
            output_points=out2_)
    }

    def poisson_sol(self, pts):
        return -(torch.sin(pts.extract(['x']) * torch.pi) *
                 torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2)

    truth_solution = poisson_sol


class myFeature(torch.nn.Module):
    """
    Feature: sin(x)
    """

    def __init__(self):
        super(myFeature, self).__init__()

    def forward(self, x):
        t = (torch.sin(x.extract(['x']) * torch.pi) *
             torch.sin(x.extract(['y']) * torch.pi))
        return LabelTensor(t, ['sin(x)sin(y)'])


# make the problem
poisson_problem = Poisson()
model = FeedForward(len(poisson_problem.input_variables),
                    len(poisson_problem.output_variables))
model_extra_feats = FeedForward(
    len(poisson_problem.input_variables) + 1,
    len(poisson_problem.output_variables))
extra_feats = [myFeature()]


def test_constructor():
    PINN(problem=poisson_problem, model=model)
    PINN(problem=poisson_problem, model=model, discriminator=model)


def test_constructor_extra_feats():
    with pytest.raises(TypeError):
        model_extra_feats = FeedForward(
            len(poisson_problem.input_variables) + 1,
            len(poisson_problem.output_variables))
        PINN(problem=poisson_problem,
             model=model_extra_feats,
             extra_features=extra_feats)


def test_train_cpu():
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem, model=model, loss=LpLoss())
    trainer = Trainer(solver=pinn, max_epochs=1,
                      accelerator='cpu', batch_size=20)
    trainer.train()


def test_train_restore():
    tmpdir = "tests/tmp_restore"
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem,
                model=model,
                loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=5,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
    t = ntrainer.train(
        ckpt_path=f'{tmpdir}/lightning_logs/version_0/'
        'checkpoints/epoch=4-step=10.ckpt')
    import shutil
    shutil.rmtree(tmpdir)


def test_train_load():
    tmpdir = "tests/tmp_load"
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem,
|
||||||
|
model=model,
|
||||||
|
loss=LpLoss())
|
||||||
|
trainer = Trainer(solver=pinn,
|
||||||
|
max_epochs=15,
|
||||||
|
accelerator='cpu',
|
||||||
|
default_root_dir=tmpdir)
|
||||||
|
trainer.train()
|
||||||
|
new_pinn = PINN.load_from_checkpoint(
|
||||||
|
f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
|
||||||
|
problem = poisson_problem, model=model)
|
||||||
|
test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
|
||||||
|
assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
|
||||||
|
assert new_pinn.forward(test_pts).extract(
|
||||||
|
['u']).shape == pinn.forward(test_pts).extract(['u']).shape
|
||||||
|
torch.testing.assert_close(
|
||||||
|
new_pinn.forward(test_pts).extract(['u']),
|
||||||
|
pinn.forward(test_pts).extract(['u']))
|
||||||
|
import shutil
|
||||||
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
def test_train_inverse_problem_cpu():
|
||||||
|
poisson_problem = InversePoisson()
|
||||||
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
|
||||||
|
n = 100
|
||||||
|
poisson_problem.discretise_domain(n, 'random', locations=boundaries)
|
||||||
|
pinn = PINN(problem = poisson_problem, model=model, loss=LpLoss())
|
||||||
|
trainer = Trainer(solver=pinn, max_epochs=1,
|
||||||
|
accelerator='cpu', batch_size=20)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
|
# # TODO does not currently work
|
||||||
|
# def test_train_inverse_problem_restore():
|
||||||
|
# tmpdir = "tests/tmp_restore_inv"
|
||||||
|
# poisson_problem = InversePoisson()
|
||||||
|
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
|
||||||
|
# n = 100
|
||||||
|
# poisson_problem.discretise_domain(n, 'random', locations=boundaries)
|
||||||
|
# pinn = PINN(problem=poisson_problem,
|
||||||
|
# model=model,
|
||||||
|
# loss=LpLoss())
|
||||||
|
# trainer = Trainer(solver=pinn,
|
||||||
|
# max_epochs=5,
|
||||||
|
# accelerator='cpu',
|
||||||
|
# default_root_dir=tmpdir)
|
||||||
|
# trainer.train()
|
||||||
|
# ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
|
||||||
|
# t = ntrainer.train(
|
||||||
|
# ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt')
|
||||||
|
# import shutil
|
||||||
|
# shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_inverse_problem_load():
|
||||||
|
tmpdir = "tests/tmp_load_inv"
|
||||||
|
poisson_problem = InversePoisson()
|
||||||
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
|
||||||
|
n = 100
|
||||||
|
poisson_problem.discretise_domain(n, 'random', locations=boundaries)
|
||||||
|
pinn = PINN(problem=poisson_problem,
|
||||||
|
model=model,
|
||||||
|
loss=LpLoss())
|
||||||
|
trainer = Trainer(solver=pinn,
|
||||||
|
max_epochs=15,
|
||||||
|
accelerator='cpu',
|
||||||
|
default_root_dir=tmpdir)
|
||||||
|
trainer.train()
|
||||||
|
new_pinn = PINN.load_from_checkpoint(
|
||||||
|
f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
|
||||||
|
problem = poisson_problem, model=model)
|
||||||
|
test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
|
||||||
|
assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
|
||||||
|
assert new_pinn.forward(test_pts).extract(
|
||||||
|
['u']).shape == pinn.forward(test_pts).extract(['u']).shape
|
||||||
|
torch.testing.assert_close(
|
||||||
|
new_pinn.forward(test_pts).extract(['u']),
|
||||||
|
pinn.forward(test_pts).extract(['u']))
|
||||||
|
import shutil
|
||||||
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
# # TODO fix asap. Basically sampling few variables
|
||||||
|
# # works only if both variables are in a range.
|
||||||
|
# # if one is fixed and the other not, this will
|
||||||
|
# # not work. This test also needs to be fixed and
|
||||||
|
# # insert in test problem not in test pinn.
|
||||||
|
# def test_train_cpu_sampling_few_vars():
|
||||||
|
# poisson_problem = Poisson()
|
||||||
|
# boundaries = ['gamma1', 'gamma2', 'gamma3']
|
||||||
|
# n = 10
|
||||||
|
# poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['x'])
|
||||||
|
# poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['y'])
|
||||||
|
# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
||||||
|
# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
|
||||||
|
# trainer.train()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO, fix GitHub actions to run also on GPU
|
||||||
|
# def test_train_gpu():
|
||||||
|
# poisson_problem = Poisson()
|
||||||
|
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# n = 10
|
||||||
|
# poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
||||||
|
# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
|
||||||
|
# trainer.train()
|
||||||
|
|
||||||
|
# def test_train_gpu(): #TODO fix ASAP
|
||||||
|
# poisson_problem = Poisson()
|
||||||
|
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# n = 10
|
||||||
|
# poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu
|
||||||
|
# pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
||||||
|
# trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
|
||||||
|
# trainer.train()
|
||||||
|
|
||||||
|
# def test_train_2():
|
||||||
|
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# n = 10
|
||||||
|
# expected_keys = [[], list(range(0, 50, 3))]
|
||||||
|
# param = [0, 3]
|
||||||
|
# for i, truth_key in zip(param, expected_keys):
|
||||||
|
# pinn = PINN(problem, model)
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||||
|
# pinn.train(50, save_loss=i)
|
||||||
|
# assert list(pinn.history_loss.keys()) == truth_key
|
||||||
|
|
||||||
|
|
||||||
|
# def test_train_extra_feats():
|
||||||
|
# pinn = PINN(problem, model_extra_feat, [myFeature()])
|
||||||
|
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# n = 10
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||||
|
# pinn.train(5)
|
||||||
|
|
||||||
|
|
||||||
|
# def test_train_2_extra_feats():
|
||||||
|
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# n = 10
|
||||||
|
# expected_keys = [[], list(range(0, 50, 3))]
|
||||||
|
# param = [0, 3]
|
||||||
|
# for i, truth_key in zip(param, expected_keys):
|
||||||
|
# pinn = PINN(problem, model_extra_feat, [myFeature()])
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||||
|
# pinn.train(50, save_loss=i)
|
||||||
|
# assert list(pinn.history_loss.keys()) == truth_key
|
||||||
|
|
||||||
|
|
||||||
|
# def test_train_with_optimizer_kwargs():
|
||||||
|
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# n = 10
|
||||||
|
# expected_keys = [[], list(range(0, 50, 3))]
|
||||||
|
# param = [0, 3]
|
||||||
|
# for i, truth_key in zip(param, expected_keys):
|
||||||
|
# pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3})
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||||
|
# pinn.train(50, save_loss=i)
|
||||||
|
# assert list(pinn.history_loss.keys()) == truth_key
|
||||||
|
|
||||||
|
|
||||||
|
# def test_train_with_lr_scheduler():
|
||||||
|
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# n = 10
|
||||||
|
# expected_keys = [[], list(range(0, 50, 3))]
|
||||||
|
# param = [0, 3]
|
||||||
|
# for i, truth_key in zip(param, expected_keys):
|
||||||
|
# pinn = PINN(
|
||||||
|
# problem,
|
||||||
|
# model,
|
||||||
|
# lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
|
||||||
|
# lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False}
|
||||||
|
# )
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||||
|
# pinn.train(50, save_loss=i)
|
||||||
|
# assert list(pinn.history_loss.keys()) == truth_key
|
||||||
|
|
||||||
|
|
||||||
|
# # def test_train_batch():
|
||||||
|
# # pinn = PINN(problem, model, batch_size=6)
|
||||||
|
# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# # n = 10
|
||||||
|
# # pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# # pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||||
|
# # pinn.train(5)
|
||||||
|
|
||||||
|
|
||||||
|
# # def test_train_batch_2():
|
||||||
|
# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# # n = 10
|
||||||
|
# # expected_keys = [[], list(range(0, 50, 3))]
|
||||||
|
# # param = [0, 3]
|
||||||
|
# # for i, truth_key in zip(param, expected_keys):
|
||||||
|
# # pinn = PINN(problem, model, batch_size=6)
|
||||||
|
# # pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# # pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||||
|
# # pinn.train(50, save_loss=i)
|
||||||
|
# # assert list(pinn.history_loss.keys()) == truth_key
|
||||||
|
|
||||||
|
|
||||||
|
# if torch.cuda.is_available():
|
||||||
|
|
||||||
|
# # def test_gpu_train():
|
||||||
|
# # pinn = PINN(problem, model, batch_size=20, device='cuda')
|
||||||
|
# # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# # n = 100
|
||||||
|
# # pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# # pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||||
|
# # pinn.train(5)
|
||||||
|
|
||||||
|
# def test_gpu_train_nobatch():
|
||||||
|
# pinn = PINN(problem, model, batch_size=None, device='cuda')
|
||||||
|
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
# n = 100
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
# pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||||
|
# pinn.train(5)
|
||||||
|
|
||||||
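The training workflow these tests exercise is the same throughout the suite. A minimal sketch, assuming the PINA API exactly as used above (`PINN` is the solver class imported at the top of this test file, and the optional `discriminator` argument is the one its constructor accepts):

    # sketch: discretise the problem, wrap it in a solver, train with the Trainer
    problem = Poisson()
    problem.discretise_domain(10, 'grid',
                              locations=['gamma1', 'gamma2', 'gamma3', 'gamma4'])
    solver = PINN(problem=problem, model=model)  # optionally: discriminator=model
    trainer = Trainer(solver=solver, max_epochs=1, accelerator='cpu', batch_size=20)
    trainer.train()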
tests/test_solvers/test_gpinn.py (new file, 432 lines)
@@ -0,0 +1,432 @@
import torch

from pina.problem import SpatialProblem, InverseProblem
from pina.operators import laplacian
from pina.geometry import CartesianDomain
from pina import Condition, LabelTensor
from pina.solvers import GPINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
from pina.loss import LpLoss


def laplace_equation(input_, output_):
    force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
                  torch.sin(input_.extract(['y']) * torch.pi))
    delta_u = laplacian(output_.extract(['u']), input_)
    return delta_u - force_term


my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
out2_ = LabelTensor(torch.rand(60, 1), ['u'])


class InversePoisson(SpatialProblem, InverseProblem):
    '''
    Problem definition for the Poisson equation.
    '''
    output_variables = ['u']
    x_min = -2
    x_max = 2
    y_min = -2
    y_max = 2
    data_input = LabelTensor(torch.rand(10, 2), ['x', 'y'])
    data_output = LabelTensor(torch.rand(10, 1), ['u'])
    spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]})
    # define the ranges for the parameters
    unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]})

    def laplace_equation(input_, output_, params_):
        '''
        Laplace equation with a force term.
        '''
        force_term = torch.exp(
            - 2*(input_.extract(['x']) - params_['mu1'])**2
            - 2*(input_.extract(['y']) - params_['mu2'])**2)
        delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y'])

        return delta_u - force_term

    # define the conditions for the loss (boundary conditions, equation, data)
    conditions = {
        'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max],
                                                      'y': y_max}),
                            equation=FixedValue(0.0, components=['u'])),
        'gamma2': Condition(location=CartesianDomain(
            {'x': [x_min, x_max], 'y': y_min}),
            equation=FixedValue(0.0, components=['u'])),
        'gamma3': Condition(location=CartesianDomain(
            {'x': x_max, 'y': [y_min, y_max]}),
            equation=FixedValue(0.0, components=['u'])),
        'gamma4': Condition(location=CartesianDomain(
            {'x': x_min, 'y': [y_min, y_max]}),
            equation=FixedValue(0.0, components=['u'])),
        'D': Condition(location=CartesianDomain(
            {'x': [x_min, x_max], 'y': [y_min, y_max]}),
            equation=Equation(laplace_equation)),
        'data': Condition(
            input_points=data_input.extract(['x', 'y']),
            output_points=data_output)
    }


class Poisson(SpatialProblem):
    output_variables = ['u']
    spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})

    conditions = {
        'gamma1': Condition(
            location=CartesianDomain({'x': [0, 1], 'y': 1}),
            equation=FixedValue(0.0)),
        'gamma2': Condition(
            location=CartesianDomain({'x': [0, 1], 'y': 0}),
            equation=FixedValue(0.0)),
        'gamma3': Condition(
            location=CartesianDomain({'x': 1, 'y': [0, 1]}),
            equation=FixedValue(0.0)),
        'gamma4': Condition(
            location=CartesianDomain({'x': 0, 'y': [0, 1]}),
            equation=FixedValue(0.0)),
        'D': Condition(
            input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
            equation=my_laplace),
        'data': Condition(
            input_points=in_,
            output_points=out_),
        'data2': Condition(
            input_points=in2_,
            output_points=out2_)
    }

    def poisson_sol(self, pts):
        return -(torch.sin(pts.extract(['x']) * torch.pi) *
                 torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2)

    truth_solution = poisson_sol


class myFeature(torch.nn.Module):
    """
    Feature: sin(x)
    """

    def __init__(self):
        super(myFeature, self).__init__()

    def forward(self, x):
        t = (torch.sin(x.extract(['x']) * torch.pi) *
             torch.sin(x.extract(['y']) * torch.pi))
        return LabelTensor(t, ['sin(x)sin(y)'])


# make the problem
poisson_problem = Poisson()
model = FeedForward(len(poisson_problem.input_variables),
                    len(poisson_problem.output_variables))
model_extra_feats = FeedForward(
    len(poisson_problem.input_variables) + 1,
    len(poisson_problem.output_variables))
extra_feats = [myFeature()]


def test_constructor():
    GPINN(problem=poisson_problem, model=model, extra_features=None)


def test_constructor_extra_feats():
    model_extra_feats = FeedForward(
        len(poisson_problem.input_variables) + 1,
        len(poisson_problem.output_variables))
    GPINN(problem=poisson_problem,
          model=model_extra_feats,
          extra_features=extra_feats)


def test_train_cpu():
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = GPINN(problem=poisson_problem,
                 model=model, extra_features=None, loss=LpLoss())
    trainer = Trainer(solver=pinn, max_epochs=1,
                      accelerator='cpu', batch_size=20)
    trainer.train()


def test_train_restore():
    tmpdir = "tests/tmp_restore"
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = GPINN(problem=poisson_problem,
                 model=model,
                 extra_features=None,
                 loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=5,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
    t = ntrainer.train(
        ckpt_path=f'{tmpdir}/lightning_logs/version_0/'
        'checkpoints/epoch=4-step=10.ckpt')
    import shutil
    shutil.rmtree(tmpdir)


def test_train_load():
    tmpdir = "tests/tmp_load"
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = GPINN(problem=poisson_problem,
                 model=model,
                 extra_features=None,
                 loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=15,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    new_pinn = GPINN.load_from_checkpoint(
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
        problem=poisson_problem, model=model)
    test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
    assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
    assert new_pinn.forward(test_pts).extract(
        ['u']).shape == pinn.forward(test_pts).extract(['u']).shape
    torch.testing.assert_close(
        new_pinn.forward(test_pts).extract(['u']),
        pinn.forward(test_pts).extract(['u']))
    import shutil
    shutil.rmtree(tmpdir)


def test_train_inverse_problem_cpu():
    poisson_problem = InversePoisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
    n = 100
    poisson_problem.discretise_domain(n, 'random', locations=boundaries)
    pinn = GPINN(problem=poisson_problem,
                 model=model, extra_features=None, loss=LpLoss())
    trainer = Trainer(solver=pinn, max_epochs=1,
                      accelerator='cpu', batch_size=20)
    trainer.train()


# # TODO does not currently work
# def test_train_inverse_problem_restore():
#     tmpdir = "tests/tmp_restore_inv"
#     poisson_problem = InversePoisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
#     n = 100
#     poisson_problem.discretise_domain(n, 'random', locations=boundaries)
#     pinn = GPINN(problem=poisson_problem,
#                  model=model,
#                  extra_features=None,
#                  loss=LpLoss())
#     trainer = Trainer(solver=pinn,
#                       max_epochs=5,
#                       accelerator='cpu',
#                       default_root_dir=tmpdir)
#     trainer.train()
#     ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
#     t = ntrainer.train(
#         ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt')
#     import shutil
#     shutil.rmtree(tmpdir)


def test_train_inverse_problem_load():
    tmpdir = "tests/tmp_load_inv"
    poisson_problem = InversePoisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
    n = 100
    poisson_problem.discretise_domain(n, 'random', locations=boundaries)
    pinn = GPINN(problem=poisson_problem,
                 model=model,
                 extra_features=None,
                 loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=15,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    new_pinn = GPINN.load_from_checkpoint(
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
        problem=poisson_problem, model=model)
    test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
    assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
    assert new_pinn.forward(test_pts).extract(
        ['u']).shape == pinn.forward(test_pts).extract(['u']).shape
    torch.testing.assert_close(
        new_pinn.forward(test_pts).extract(['u']),
        pinn.forward(test_pts).extract(['u']))
    import shutil
    shutil.rmtree(tmpdir)


# # TODO fix asap. Basically sampling few variables
# # works only if both variables are in a range.
# # if one is fixed and the other not, this will
# # not work. This test also needs to be fixed and
# # insert in test problem not in test pinn.
# def test_train_cpu_sampling_few_vars():
#     poisson_problem = Poisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3']
#     n = 10
#     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
#     poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['x'])
#     poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['y'])
#     pinn = GPINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
#     trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
#     trainer.train()


def test_train_extra_feats_cpu():
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = GPINN(problem=poisson_problem,
                 model=model_extra_feats,
                 extra_features=extra_feats)
    trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
    trainer.train()


# TODO, fix GitHub actions to run also on GPU
# def test_train_gpu():
#     poisson_problem = Poisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
#     pinn = GPINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
#     trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
#     trainer.train()

# def test_train_gpu(): #TODO fix ASAP
#     poisson_problem = Poisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
#     poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu
#     pinn = GPINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
#     trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
#     trainer.train()

# def test_train_2():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = GPINN(problem, model)
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# def test_train_extra_feats():
#     pinn = GPINN(problem, model_extra_feat, [myFeature()])
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     pinn.discretise_domain(n, 'grid', locations=boundaries)
#     pinn.discretise_domain(n, 'grid', locations=['D'])
#     pinn.train(5)


# def test_train_2_extra_feats():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = GPINN(problem, model_extra_feat, [myFeature()])
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# def test_train_with_optimizer_kwargs():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = GPINN(problem, model, optimizer_kwargs={'lr' : 0.3})
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# def test_train_with_lr_scheduler():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = GPINN(
#             problem,
#             model,
#             lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
#             lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False}
#         )
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# # def test_train_batch():
# #     pinn = GPINN(problem, model, batch_size=6)
# #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
# #     n = 10
# #     pinn.discretise_domain(n, 'grid', locations=boundaries)
# #     pinn.discretise_domain(n, 'grid', locations=['D'])
# #     pinn.train(5)


# # def test_train_batch_2():
# #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
# #     n = 10
# #     expected_keys = [[], list(range(0, 50, 3))]
# #     param = [0, 3]
# #     for i, truth_key in zip(param, expected_keys):
# #         pinn = GPINN(problem, model, batch_size=6)
# #         pinn.discretise_domain(n, 'grid', locations=boundaries)
# #         pinn.discretise_domain(n, 'grid', locations=['D'])
# #         pinn.train(50, save_loss=i)
# #         assert list(pinn.history_loss.keys()) == truth_key


# if torch.cuda.is_available():

# # def test_gpu_train():
# #     pinn = GPINN(problem, model, batch_size=20, device='cuda')
# #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
# #     n = 100
# #     pinn.discretise_domain(n, 'grid', locations=boundaries)
# #     pinn.discretise_domain(n, 'grid', locations=['D'])
# #     pinn.train(5)

# def test_gpu_train_nobatch():
#     pinn = GPINN(problem, model, batch_size=None, device='cuda')
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 100
#     pinn.discretise_domain(n, 'grid', locations=boundaries)
#     pinn.discretise_domain(n, 'grid', locations=['D'])
#     pinn.train(5)
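The GPINN suite above mirrors the PINN tests nearly line for line: the solver is a drop-in replacement, so only the class changes. A minimal sketch of the swap, reusing the fixtures defined in the test file (no API beyond what the tests themselves call):

    solver = GPINN(problem=poisson_problem, model=model,
                   extra_features=None, loss=LpLoss())
    trainer = Trainer(solver=solver, max_epochs=1, accelerator='cpu', batch_size=20)
    trainer.train()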
@@ -1,6 +1,6 @@
 import torch

-from pina.problem import SpatialProblem
+from pina.problem import SpatialProblem, InverseProblem
 from pina.operators import laplacian
 from pina.geometry import CartesianDomain
 from pina import Condition, LabelTensor
@@ -26,6 +26,58 @@ in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
 out2_ = LabelTensor(torch.rand(60, 1), ['u'])


+class InversePoisson(SpatialProblem, InverseProblem):
+    '''
+    Problem definition for the Poisson equation.
+    '''
+    output_variables = ['u']
+    x_min = -2
+    x_max = 2
+    y_min = -2
+    y_max = 2
+    data_input = LabelTensor(torch.rand(10, 2), ['x', 'y'])
+    data_output = LabelTensor(torch.rand(10, 1), ['u'])
+    spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]})
+    # define the ranges for the parameters
+    unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]})
+
+    def laplace_equation(input_, output_, params_):
+        '''
+        Laplace equation with a force term.
+        '''
+        force_term = torch.exp(
+            - 2*(input_.extract(['x']) - params_['mu1'])**2
+            - 2*(input_.extract(['y']) - params_['mu2'])**2)
+        delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y'])
+
+        return delta_u - force_term
+
+    # define the conditions for the loss (boundary conditions, equation, data)
+    conditions = {
+        'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max],
+                                                      'y': y_max}),
+                            equation=FixedValue(0.0, components=['u'])),
+        'gamma2': Condition(location=CartesianDomain(
+            {'x': [x_min, x_max], 'y': y_min}),
+            equation=FixedValue(0.0, components=['u'])),
+        'gamma3': Condition(location=CartesianDomain(
+            {'x': x_max, 'y': [y_min, y_max]}),
+            equation=FixedValue(0.0, components=['u'])),
+        'gamma4': Condition(location=CartesianDomain(
+            {'x': x_min, 'y': [y_min, y_max]}),
+            equation=FixedValue(0.0, components=['u'])),
+        'D': Condition(location=CartesianDomain(
+            {'x': [x_min, x_max], 'y': [y_min, y_max]}),
+            equation=Equation(laplace_equation)),
+        'data': Condition(input_points=data_input.extract(['x', 'y']),
+                          output_points=data_output)
+    }
+
+
 class Poisson(SpatialProblem):
     output_variables = ['u']
     spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
@@ -103,8 +155,10 @@ def test_train_cpu():
     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
     n = 10
     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
-    pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
-    trainer = Trainer(solver=pinn, max_epochs=1, accelerator='cpu', batch_size=20)
+    pinn = PINN(problem = poisson_problem, model=model,
+                extra_features=None, loss=LpLoss())
+    trainer = Trainer(solver=pinn, max_epochs=1,
+                      accelerator='cpu', batch_size=20)
     trainer.train()
@@ -125,7 +179,8 @@ def test_train_restore():
     trainer.train()
     ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
     t = ntrainer.train(
-        ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt')
+        ckpt_path=f'{tmpdir}/lightning_logs/version_0/'
+        'checkpoints/epoch=4-step=10.ckpt')
     import shutil
     shutil.rmtree(tmpdir)
@@ -158,6 +213,68 @@ def test_train_load():
     import shutil
     shutil.rmtree(tmpdir)


+def test_train_inverse_problem_cpu():
+    poisson_problem = InversePoisson()
+    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
+    n = 100
+    poisson_problem.discretise_domain(n, 'random', locations=boundaries)
+    pinn = PINN(problem = poisson_problem, model=model,
+                extra_features=None, loss=LpLoss())
+    trainer = Trainer(solver=pinn, max_epochs=1,
+                      accelerator='cpu', batch_size=20)
+    trainer.train()
+
+
+# # TODO does not currently work
+# def test_train_inverse_problem_restore():
+#     tmpdir = "tests/tmp_restore_inv"
+#     poisson_problem = InversePoisson()
+#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
+#     n = 100
+#     poisson_problem.discretise_domain(n, 'random', locations=boundaries)
+#     pinn = PINN(problem=poisson_problem,
+#                 model=model,
+#                 extra_features=None,
+#                 loss=LpLoss())
+#     trainer = Trainer(solver=pinn,
+#                       max_epochs=5,
+#                       accelerator='cpu',
+#                       default_root_dir=tmpdir)
+#     trainer.train()
+#     ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
+#     t = ntrainer.train(
+#         ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt')
+#     import shutil
+#     shutil.rmtree(tmpdir)
+
+
+def test_train_inverse_problem_load():
+    tmpdir = "tests/tmp_load_inv"
+    poisson_problem = InversePoisson()
+    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
+    n = 100
+    poisson_problem.discretise_domain(n, 'random', locations=boundaries)
+    pinn = PINN(problem=poisson_problem,
+                model=model,
+                extra_features=None,
+                loss=LpLoss())
+    trainer = Trainer(solver=pinn,
+                      max_epochs=15,
+                      accelerator='cpu',
+                      default_root_dir=tmpdir)
+    trainer.train()
+    new_pinn = PINN.load_from_checkpoint(
+        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
+        problem = poisson_problem, model=model)
+    test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
+    assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
+    assert new_pinn.forward(test_pts).extract(
+        ['u']).shape == pinn.forward(test_pts).extract(['u']).shape
+    torch.testing.assert_close(
+        new_pinn.forward(test_pts).extract(['u']),
+        pinn.forward(test_pts).extract(['u']))
+    import shutil
+    shutil.rmtree(tmpdir)
+
+
 # # TODO fix asap. Basically sampling few variables
 # # works only if both variables are in a range.
@@ -197,85 +314,32 @@ def test_train_extra_feats_cpu():
 #     pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
 #     trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
 #     trainer.train()
-"""
-def test_train_gpu(): #TODO fix ASAP
-    poisson_problem = Poisson()
-    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
-    n = 10
-    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
-    poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu
-    pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
-    trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
-    trainer.train()
-
-def test_train_2():
-    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
-    n = 10
-    expected_keys = [[], list(range(0, 50, 3))]
-    param = [0, 3]
-    for i, truth_key in zip(param, expected_keys):
-        pinn = PINN(problem, model)
-        pinn.discretise_domain(n, 'grid', locations=boundaries)
-        pinn.discretise_domain(n, 'grid', locations=['D'])
-        pinn.train(50, save_loss=i)
-        assert list(pinn.history_loss.keys()) == truth_key
-
-
-def test_train_extra_feats():
-    pinn = PINN(problem, model_extra_feat, [myFeature()])
-    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
-    n = 10
-    pinn.discretise_domain(n, 'grid', locations=boundaries)
-    pinn.discretise_domain(n, 'grid', locations=['D'])
-    pinn.train(5)
-
-
-def test_train_2_extra_feats():
-    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
-    n = 10
-    expected_keys = [[], list(range(0, 50, 3))]
-    param = [0, 3]
-    for i, truth_key in zip(param, expected_keys):
-        pinn = PINN(problem, model_extra_feat, [myFeature()])
-        pinn.discretise_domain(n, 'grid', locations=boundaries)
-        pinn.discretise_domain(n, 'grid', locations=['D'])
-        pinn.train(50, save_loss=i)
-        assert list(pinn.history_loss.keys()) == truth_key
-
-
-def test_train_with_optimizer_kwargs():
-    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
-    n = 10
-    expected_keys = [[], list(range(0, 50, 3))]
-    param = [0, 3]
-    for i, truth_key in zip(param, expected_keys):
-        pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3})
-        pinn.discretise_domain(n, 'grid', locations=boundaries)
-        pinn.discretise_domain(n, 'grid', locations=['D'])
-        pinn.train(50, save_loss=i)
-        assert list(pinn.history_loss.keys()) == truth_key
-
-
-def test_train_with_lr_scheduler():
-    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
-    n = 10
-    expected_keys = [[], list(range(0, 50, 3))]
-    param = [0, 3]
-    for i, truth_key in zip(param, expected_keys):
-        pinn = PINN(
-            problem,
-            model,
-            lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
-            lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False}
-        )
-        pinn.discretise_domain(n, 'grid', locations=boundaries)
-        pinn.discretise_domain(n, 'grid', locations=['D'])
-        pinn.train(50, save_loss=i)
-        assert list(pinn.history_loss.keys()) == truth_key
-
-
-# def test_train_batch():
-#     pinn = PINN(problem, model, batch_size=6)
+
+# def test_train_gpu(): #TODO fix ASAP
+#     poisson_problem = Poisson()
+#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+#     n = 10
+#     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
+#     poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu
+#     pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
+#     trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
+#     trainer.train()
+
+# def test_train_2():
+#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+#     n = 10
+#     expected_keys = [[], list(range(0, 50, 3))]
+#     param = [0, 3]
+#     for i, truth_key in zip(param, expected_keys):
+#         pinn = PINN(problem, model)
+#         pinn.discretise_domain(n, 'grid', locations=boundaries)
+#         pinn.discretise_domain(n, 'grid', locations=['D'])
+#         pinn.train(50, save_loss=i)
+#         assert list(pinn.history_loss.keys()) == truth_key
+
+
+# def test_train_extra_feats():
+#     pinn = PINN(problem, model_extra_feat, [myFeature()])
 #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
 #     n = 10
 #     pinn.discretise_domain(n, 'grid', locations=boundaries)
@@ -283,34 +347,87 @@ def test_train_with_lr_scheduler():
 #     pinn.train(5)


-# def test_train_batch_2():
+# def test_train_2_extra_feats():
 #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
 #     n = 10
 #     expected_keys = [[], list(range(0, 50, 3))]
 #     param = [0, 3]
 #     for i, truth_key in zip(param, expected_keys):
-#         pinn = PINN(problem, model, batch_size=6)
+#         pinn = PINN(problem, model_extra_feat, [myFeature()])
 #         pinn.discretise_domain(n, 'grid', locations=boundaries)
 #         pinn.discretise_domain(n, 'grid', locations=['D'])
 #         pinn.train(50, save_loss=i)
 #         assert list(pinn.history_loss.keys()) == truth_key


-if torch.cuda.is_available():
-
-# def test_gpu_train():
-#     pinn = PINN(problem, model, batch_size=20, device='cuda')
-#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
-#     n = 100
-#     pinn.discretise_domain(n, 'grid', locations=boundaries)
-#     pinn.discretise_domain(n, 'grid', locations=['D'])
-#     pinn.train(5)
-
-def test_gpu_train_nobatch():
-    pinn = PINN(problem, model, batch_size=None, device='cuda')
-    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
-    n = 100
-    pinn.discretise_domain(n, 'grid', locations=boundaries)
-    pinn.discretise_domain(n, 'grid', locations=['D'])
-    pinn.train(5)
-"""
+# def test_train_with_optimizer_kwargs():
+#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+#     n = 10
+#     expected_keys = [[], list(range(0, 50, 3))]
+#     param = [0, 3]
+#     for i, truth_key in zip(param, expected_keys):
+#         pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3})
+#         pinn.discretise_domain(n, 'grid', locations=boundaries)
+#         pinn.discretise_domain(n, 'grid', locations=['D'])
+#         pinn.train(50, save_loss=i)
+#         assert list(pinn.history_loss.keys()) == truth_key
+
+
+# def test_train_with_lr_scheduler():
+#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+#     n = 10
+#     expected_keys = [[], list(range(0, 50, 3))]
+#     param = [0, 3]
+#     for i, truth_key in zip(param, expected_keys):
+#         pinn = PINN(
+#             problem,
+#             model,
+#             lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
+#             lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False}
+#         )
+#         pinn.discretise_domain(n, 'grid', locations=boundaries)
+#         pinn.discretise_domain(n, 'grid', locations=['D'])
+#         pinn.train(50, save_loss=i)
+#         assert list(pinn.history_loss.keys()) == truth_key
+
+
+# # def test_train_batch():
+# #     pinn = PINN(problem, model, batch_size=6)
+# #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+# #     n = 10
+# #     pinn.discretise_domain(n, 'grid', locations=boundaries)
+# #     pinn.discretise_domain(n, 'grid', locations=['D'])
+# #     pinn.train(5)
+
+
+# # def test_train_batch_2():
+# #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+# #     n = 10
+# #     expected_keys = [[], list(range(0, 50, 3))]
+# #     param = [0, 3]
+# #     for i, truth_key in zip(param, expected_keys):
+# #         pinn = PINN(problem, model, batch_size=6)
+# #         pinn.discretise_domain(n, 'grid', locations=boundaries)
+# #         pinn.discretise_domain(n, 'grid', locations=['D'])
+# #         pinn.train(50, save_loss=i)
+# #         assert list(pinn.history_loss.keys()) == truth_key
+
+
+# if torch.cuda.is_available():
+
+# # def test_gpu_train():
+# #     pinn = PINN(problem, model, batch_size=20, device='cuda')
+# #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+# #     n = 100
+# #     pinn.discretise_domain(n, 'grid', locations=boundaries)
+# #     pinn.discretise_domain(n, 'grid', locations=['D'])
+# #     pinn.train(5)
+
+# def test_gpu_train_nobatch():
+#     pinn = PINN(problem, model, batch_size=None, device='cuda')
+#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+#     n = 100
+#     pinn.discretise_domain(n, 'grid', locations=boundaries)
+#     pinn.discretise_domain(n, 'grid', locations=['D'])
+#     pinn.train(5)
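One detail worth noting in the hunks above: the long checkpoint paths are now wrapped across two lines using Python's implicit concatenation of adjacent string literals, so the resulting path is identical to the original single-line version. A quick check:

    tmpdir = "tests/tmp_restore"
    wrapped = (f'{tmpdir}/lightning_logs/version_0/'
               'checkpoints/epoch=4-step=10.ckpt')
    assert wrapped == f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt'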
tests/test_solvers/test_rom_solver.py (new file, 105 lines)
@@ -0,0 +1,105 @@
import torch
import pytest

from pina.problem import AbstractProblem
from pina import Condition, LabelTensor
from pina.solvers import ReducedOrderModelSolver
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.loss import LpLoss


class NeuralOperatorProblem(AbstractProblem):
    input_variables = ['u_0', 'u_1']
    output_variables = [f'u_{i}' for i in range(100)]
    conditions = {'data' : Condition(input_points=
                                     LabelTensor(torch.rand(10, 2),
                                                 input_variables),
                                     output_points=
                                     LabelTensor(torch.rand(10, 100),
                                                 output_variables))}


# make the problem + extra feats
class AE(torch.nn.Module):
    def __init__(self, input_dimensions, rank):
        super().__init__()
        self.encode = FeedForward(input_dimensions, rank, layers=[input_dimensions//4])
        self.decode = FeedForward(rank, input_dimensions, layers=[input_dimensions//4])


class AE_missing_encode(torch.nn.Module):
    def __init__(self, input_dimensions, rank):
        super().__init__()
        self.encode = FeedForward(input_dimensions, rank, layers=[input_dimensions//4])


class AE_missing_decode(torch.nn.Module):
    def __init__(self, input_dimensions, rank):
        super().__init__()
        self.decode = FeedForward(rank, input_dimensions, layers=[input_dimensions//4])


rank = 10
problem = NeuralOperatorProblem()
interpolation_net = FeedForward(len(problem.input_variables),
                                rank)
reduction_net = AE(len(problem.output_variables), rank)


def test_constructor():
    ReducedOrderModelSolver(problem=problem, reduction_network=reduction_net,
                            interpolation_network=interpolation_net)
    with pytest.raises(SyntaxError):
        ReducedOrderModelSolver(problem=problem,
                                reduction_network=AE_missing_encode(
                                    len(problem.output_variables), rank),
                                interpolation_network=interpolation_net)
        ReducedOrderModelSolver(problem=problem,
                                reduction_network=AE_missing_decode(
                                    len(problem.output_variables), rank),
                                interpolation_network=interpolation_net)


def test_train_cpu():
    solver = ReducedOrderModelSolver(problem=problem, reduction_network=reduction_net,
                                     interpolation_network=interpolation_net, loss=LpLoss())
    trainer = Trainer(solver=solver, max_epochs=3, accelerator='cpu', batch_size=20)
    trainer.train()


def test_train_restore():
    tmpdir = "tests/tmp_restore"
    solver = ReducedOrderModelSolver(problem=problem,
                                     reduction_network=reduction_net,
                                     interpolation_network=interpolation_net,
                                     loss=LpLoss())
    trainer = Trainer(solver=solver,
                      max_epochs=5,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    ntrainer = Trainer(solver=solver, max_epochs=15, accelerator='cpu')
    t = ntrainer.train(
        ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
    import shutil
    shutil.rmtree(tmpdir)


def test_train_load():
    tmpdir = "tests/tmp_load"
    solver = ReducedOrderModelSolver(problem=problem,
                                     reduction_network=reduction_net,
                                     interpolation_network=interpolation_net,
                                     loss=LpLoss())
    trainer = Trainer(solver=solver,
                      max_epochs=15,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    new_solver = ReducedOrderModelSolver.load_from_checkpoint(
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
        problem=problem, reduction_network=reduction_net,
        interpolation_network=interpolation_net)
    test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables)
    assert new_solver.forward(test_pts).shape == (20, 100)
    assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
    torch.testing.assert_close(
        new_solver.forward(test_pts),
        solver.forward(test_pts))
    import shutil
    shutil.rmtree(tmpdir)
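As the constructor test documents, `ReducedOrderModelSolver` requires the reduction network to expose both an `encode` and a `decode` module and raises `SyntaxError` otherwise (note that the helper names read inverted: `AE_missing_encode` is the one that still has `encode` and lacks `decode`). A minimal sketch of a compliant reduction network; `MyAE` is a hypothetical name that mirrors the `AE` class above:

    class MyAE(torch.nn.Module):
        def __init__(self, input_dimensions, rank):
            super().__init__()
            # both attributes are mandatory for the solver's sanity check
            self.encode = FeedForward(input_dimensions, rank)
            self.decode = FeedForward(rank, input_dimensions)

    solver = ReducedOrderModelSolver(
        problem=problem,
        reduction_network=MyAE(len(problem.output_variables), rank),
        interpolation_network=FeedForward(len(problem.input_variables), rank))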
tests/test_solvers/test_sapinn.py (new file, 437 lines)
@@ -0,0 +1,437 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pina.problem import SpatialProblem, InverseProblem
|
||||||
|
from pina.operators import laplacian
|
||||||
|
from pina.geometry import CartesianDomain
|
||||||
|
from pina import Condition, LabelTensor
|
||||||
|
from pina.solvers import SAPINN as PINN
|
||||||
|
from pina.trainer import Trainer
|
||||||
|
from pina.model import FeedForward
|
||||||
|
from pina.equation.equation import Equation
|
||||||
|
from pina.equation.equation_factory import FixedValue
|
||||||
|
from pina.loss import LpLoss
|
||||||
|
|
||||||
|
|
||||||
|
def laplace_equation(input_, output_):
|
||||||
|
force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
|
||||||
|
torch.sin(input_.extract(['y']) * torch.pi))
|
||||||
|
delta_u = laplacian(output_.extract(['u']), input_)
|
||||||
|
return delta_u - force_term
|
||||||
|
|
||||||
|
|
||||||
|
my_laplace = Equation(laplace_equation)
|
||||||
|
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
|
||||||
|
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
|
||||||
|
in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
|
||||||
|
out2_ = LabelTensor(torch.rand(60, 1), ['u'])
|
||||||
|
|


class InversePoisson(SpatialProblem, InverseProblem):
    '''
    Problem definition for the inverse Poisson problem.
    '''
    output_variables = ['u']
    x_min = -2
    x_max = 2
    y_min = -2
    y_max = 2
    data_input = LabelTensor(torch.rand(10, 2), ['x', 'y'])
    data_output = LabelTensor(torch.rand(10, 1), ['u'])
    spatial_domain = CartesianDomain({'x': [x_min, x_max],
                                      'y': [y_min, y_max]})
    # define the ranges for the unknown parameters
    unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1],
                                                'mu2': [-1, 1]})

    def laplace_equation(input_, output_, params_):
        '''
        Laplace equation with a parametric Gaussian force term.
        '''
        force_term = torch.exp(
            - 2 * (input_.extract(['x']) - params_['mu1'])**2
            - 2 * (input_.extract(['y']) - params_['mu2'])**2)
        delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y'])
        return delta_u - force_term

    # define the conditions for the loss (boundary conditions, equation, data)
    conditions = {
        'gamma1': Condition(
            location=CartesianDomain({'x': [x_min, x_max], 'y': y_max}),
            equation=FixedValue(0.0, components=['u'])),
        'gamma2': Condition(
            location=CartesianDomain({'x': [x_min, x_max], 'y': y_min}),
            equation=FixedValue(0.0, components=['u'])),
        'gamma3': Condition(
            location=CartesianDomain({'x': x_max, 'y': [y_min, y_max]}),
            equation=FixedValue(0.0, components=['u'])),
        'gamma4': Condition(
            location=CartesianDomain({'x': x_min, 'y': [y_min, y_max]}),
            equation=FixedValue(0.0, components=['u'])),
        'D': Condition(
            location=CartesianDomain({'x': [x_min, x_max],
                                      'y': [y_min, y_max]}),
            equation=Equation(laplace_equation)),
        'data': Condition(input_points=data_input.extract(['x', 'y']),
                          output_points=data_output)
    }
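# In the inverse setting, the unknown parameters mu1 and mu2 are learned
# jointly with the network weights, driven by the supervised 'data' condition.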


class Poisson(SpatialProblem):
    output_variables = ['u']
    spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})

    conditions = {
        'gamma1': Condition(
            location=CartesianDomain({'x': [0, 1], 'y': 1}),
            equation=FixedValue(0.0)),
        'gamma2': Condition(
            location=CartesianDomain({'x': [0, 1], 'y': 0}),
            equation=FixedValue(0.0)),
        'gamma3': Condition(
            location=CartesianDomain({'x': 1, 'y': [0, 1]}),
            equation=FixedValue(0.0)),
        'gamma4': Condition(
            location=CartesianDomain({'x': 0, 'y': [0, 1]}),
            equation=FixedValue(0.0)),
        'D': Condition(
            input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
            equation=my_laplace),
        'data': Condition(
            input_points=in_,
            output_points=out_),
        'data2': Condition(
            input_points=in2_,
            output_points=out2_)
    }

    def poisson_sol(self, pts):
        return -(torch.sin(pts.extract(['x']) * torch.pi) *
                 torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2)

    truth_solution = poisson_sol


class myFeature(torch.nn.Module):
    """
    Feature: sin(pi * x) * sin(pi * y)
    """

    def __init__(self):
        super(myFeature, self).__init__()

    def forward(self, x):
        t = (torch.sin(x.extract(['x']) * torch.pi) *
             torch.sin(x.extract(['y']) * torch.pi))
        return LabelTensor(t, ['sin(x)sin(y)'])


# make the problem and the models
poisson_problem = Poisson()
model = FeedForward(len(poisson_problem.input_variables),
                    len(poisson_problem.output_variables))
model_extra_feats = FeedForward(
    len(poisson_problem.input_variables) + 1,
    len(poisson_problem.output_variables))
extra_feats = [myFeature()]


def test_constructor():
    PINN(problem=poisson_problem, model=model, extra_features=None)
    with pytest.raises(ValueError):
        PINN(problem=poisson_problem, model=model, extra_features=None,
             weights_function=1)
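# The weights_function argument shapes SAPINN's self-adaptive weights, so a
# non-callable value such as the integer 1 is expected to raise a ValueError.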


def test_constructor_extra_feats():
    model_extra_feats = FeedForward(
        len(poisson_problem.input_variables) + 1,
        len(poisson_problem.output_variables))
    PINN(problem=poisson_problem,
         model=model_extra_feats,
         extra_features=extra_feats)


def test_train_cpu():
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem, model=model,
                extra_features=None, loss=LpLoss())
    trainer = Trainer(solver=pinn, max_epochs=1,
                      accelerator='cpu', batch_size=20)
    trainer.train()

def test_train_restore():
    tmpdir = "tests/tmp_restore"
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem,
                model=model,
                extra_features=None,
                loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=5,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
    t = ntrainer.train(
        ckpt_path=f'{tmpdir}/lightning_logs/version_0/'
        'checkpoints/epoch=4-step=10.ckpt')
    import shutil
    shutil.rmtree(tmpdir)

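# Note: checkpoint file names encode the epoch and the global optimizer step
# (epoch=4-step=10 here), so they depend on how many optimizer steps each
# epoch performs.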


def test_train_load():
    tmpdir = "tests/tmp_load"
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem,
                model=model,
                extra_features=None,
                loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=15,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    new_pinn = PINN.load_from_checkpoint(
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
        problem=poisson_problem, model=model)
    test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
    assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
    assert new_pinn.forward(test_pts).extract(
        ['u']).shape == pinn.forward(test_pts).extract(['u']).shape
    torch.testing.assert_close(
        new_pinn.forward(test_pts).extract(['u']),
        pinn.forward(test_pts).extract(['u']))
    import shutil
    shutil.rmtree(tmpdir)


def test_train_inverse_problem_cpu():
    poisson_problem = InversePoisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
    n = 100
    poisson_problem.discretise_domain(n, 'random', locations=boundaries)
    pinn = PINN(problem=poisson_problem, model=model,
                extra_features=None, loss=LpLoss())
    trainer = Trainer(solver=pinn, max_epochs=1,
                      accelerator='cpu', batch_size=20)
    trainer.train()

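# Resuming an inverse-problem run from a checkpoint is a known gap (see the
# TODO below); loading a fully trained inverse solver is tested afterwards.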

# TODO does not currently work
# def test_train_inverse_problem_restore():
#     tmpdir = "tests/tmp_restore_inv"
#     poisson_problem = InversePoisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
#     n = 100
#     poisson_problem.discretise_domain(n, 'random', locations=boundaries)
#     pinn = PINN(problem=poisson_problem,
#                 model=model,
#                 extra_features=None,
#                 loss=LpLoss())
#     trainer = Trainer(solver=pinn,
#                       max_epochs=5,
#                       accelerator='cpu',
#                       default_root_dir=tmpdir)
#     trainer.train()
#     ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
#     t = ntrainer.train(
#         ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt')
#     import shutil
#     shutil.rmtree(tmpdir)

def test_train_inverse_problem_load():
    tmpdir = "tests/tmp_load_inv"
    poisson_problem = InversePoisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
    n = 100
    poisson_problem.discretise_domain(n, 'random', locations=boundaries)
    pinn = PINN(problem=poisson_problem,
                model=model,
                extra_features=None,
                loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=15,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    new_pinn = PINN.load_from_checkpoint(
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
        problem=poisson_problem, model=model)
    test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
    assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
    assert new_pinn.forward(test_pts).extract(
        ['u']).shape == pinn.forward(test_pts).extract(['u']).shape
    torch.testing.assert_close(
        new_pinn.forward(test_pts).extract(['u']),
        pinn.forward(test_pts).extract(['u']))
    import shutil
    shutil.rmtree(tmpdir)


# TODO fix asap. Basically sampling few variables
# works only if both variables are in a range.
# If one is fixed and the other is not, this will
# not work. This test also needs to be fixed and
# moved to the problem tests, not the PINN tests.
# def test_train_cpu_sampling_few_vars():
#     poisson_problem = Poisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3']
#     n = 10
#     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
#     poisson_problem.discretise_domain(n, 'random', locations=['gamma4'],
#                                       variables=['x'])
#     poisson_problem.discretise_domain(n, 'random', locations=['gamma4'],
#                                       variables=['y'])
#     pinn = PINN(problem=poisson_problem, model=model, extra_features=None,
#                 loss=LpLoss())
#     trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
#     trainer.train()

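# Extra features are concatenated to the network input, which is why
# model_extra_feats takes len(input_variables) + 1 input dimensions.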
def test_train_extra_feats_cpu():
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem,
                model=model_extra_feats,
                extra_features=extra_feats)
    trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
    trainer.train()

# TODO, fix GitHub actions to run also on GPU
# def test_train_gpu():
#     poisson_problem = Poisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
#     pinn = PINN(problem=poisson_problem, model=model, extra_features=None,
#                 loss=LpLoss())
#     trainer = Trainer(solver=pinn, max_epochs=5, accelerator='gpu')
#     trainer.train()


# def test_train_gpu():  # TODO fix ASAP
#     poisson_problem = Poisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
#     # the input/output points of 'data' are allocated on the cpu
#     poisson_problem.conditions.pop('data')
#     pinn = PINN(problem=poisson_problem, model=model, extra_features=None,
#                 loss=LpLoss())
#     trainer = Trainer(solver=pinn, max_epochs=5, accelerator='gpu')
#     trainer.train()

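# The remaining tests target the pre-Lightning PINN API (pinn.train,
# history_loss, batch_size in the constructor) and are kept only for
# reference.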

# def test_train_2():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = PINN(problem, model)
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# def test_train_extra_feats():
#     pinn = PINN(problem, model_extra_feat, [myFeature()])
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     pinn.discretise_domain(n, 'grid', locations=boundaries)
#     pinn.discretise_domain(n, 'grid', locations=['D'])
#     pinn.train(5)


# def test_train_2_extra_feats():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = PINN(problem, model_extra_feat, [myFeature()])
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# def test_train_with_optimizer_kwargs():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = PINN(problem, model, optimizer_kwargs={'lr': 0.3})
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# def test_train_with_lr_scheduler():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = PINN(
#             problem,
#             model,
#             lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
#             lr_scheduler_kwargs={'base_lr': 0.1, 'max_lr': 0.3,
#                                  'cycle_momentum': False}
#         )
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# def test_train_batch():
#     pinn = PINN(problem, model, batch_size=6)
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     pinn.discretise_domain(n, 'grid', locations=boundaries)
#     pinn.discretise_domain(n, 'grid', locations=['D'])
#     pinn.train(5)


# def test_train_batch_2():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = PINN(problem, model, batch_size=6)
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# if torch.cuda.is_available():

#     # def test_gpu_train():
#     #     pinn = PINN(problem, model, batch_size=20, device='cuda')
#     #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     #     n = 100
#     #     pinn.discretise_domain(n, 'grid', locations=boundaries)
#     #     pinn.discretise_domain(n, 'grid', locations=['D'])
#     #     pinn.train(5)

#     def test_gpu_train_nobatch():
#         pinn = PINN(problem, model, batch_size=None, device='cuda')
#         boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#         n = 100
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(5)