PINN variants addition and Solvers Update (#263)
* gpinn/basepinn new classes, pinn restructure * codacy fix gpinn/basepinn/pinn * inverse problem fix * Causal PINN (#267) * fix GPU training in inverse problem (#283) * Create a `compute_residual` attribute for `PINNInterface` * Modify dataloading in solvers (#286) * Modify PINNInterface by removing _loss_phys, _loss_data * Adding in PINNInterface a variable to track the current condition during training * Modify GPINN,PINN,CausalPINN to match changes in PINNInterface * Competitive Pinn Addition (#288) * fixing after rebase/ fix loss * fixing final issues --------- Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local> * Modify min max formulation to max min for paper consistency * Adding SAPINN solver (#291) * rom solver * fix import --------- Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local> Co-authored-by: Anna Ivagnes <75523024+annaivagnes@users.noreply.github.com> Co-authored-by: valc89 <103250118+valc89@users.noreply.github.com> Co-authored-by: Monthly Tag bot <mtbot@noreply.github.com> Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
This commit is contained in:
@@ -1,6 +1,19 @@
|
||||
__all__ = ["PINN", "GAROM", "SupervisedSolver", "SolverInterface"]
|
||||
__all__ = [
|
||||
"SolverInterface",
|
||||
"PINNInterface",
|
||||
"PINN",
|
||||
"GPINN",
|
||||
"CausalPINN",
|
||||
"CompetitivePINN",
|
||||
"SAPINN",
|
||||
"SupervisedSolver",
|
||||
"ReducedOrderModelSolver",
|
||||
"GAROM",
|
||||
]
|
||||
|
||||
from .garom import GAROM
|
||||
from .pinn import PINN
|
||||
from .supervised import SupervisedSolver
|
||||
from .solver import SolverInterface
|
||||
from .pinns import *
|
||||
from .supervised import SupervisedSolver
|
||||
from .rom import ReducedOrderModelSolver
|
||||
from .garom import GAROM
|
||||
|
||||
|
||||
@@ -253,18 +253,11 @@ class GAROM(SolverInterface):
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
dataloader = self.trainer.train_dataloader
|
||||
condition_idx = batch["condition"]
|
||||
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
condition_name = dataloader.condition_names[condition_id]
|
||||
else:
|
||||
condition_name = dataloader.loaders.condition_names[
|
||||
condition_id
|
||||
]
|
||||
|
||||
condition_name = self._dataloader.condition_names[condition_id]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = batch["pts"].detach()
|
||||
out = batch["output"]
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
""" Module for PINN """
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||
except ImportError:
|
||||
from torch.optim.lr_scheduler import (
|
||||
_LRScheduler as LRScheduler,
|
||||
) # torch < 2.0
|
||||
|
||||
import sys
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
|
||||
from .solver import SolverInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
from ..loss import LossInterface
|
||||
from ..problem import InverseProblem
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
|
||||
|
||||
class PINN(SolverInterface):
|
||||
"""
|
||||
PINN solver class. This class implements Physics Informed Neural
|
||||
Network solvers, using a user specified ``model`` to solve a specific
|
||||
``problem``. It can be used for solving both forward and inverse problems.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L.,
|
||||
Perdikaris, P., Wang, S., & Yang, L. (2021).
|
||||
Physics-informed machine learning. Nature Reviews Physics, 3(6), 422-440.
|
||||
<https://doi.org/10.1038/s42254-021-00314-5>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
extra_features=None,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer=torch.optim.Adam,
|
||||
optimizer_kwargs={"lr": 0.001},
|
||||
scheduler=ConstantLR,
|
||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default :class:`torch.nn.MSELoss`.
|
||||
:param torch.nn.Module extra_features: The additional input
|
||||
features to use as augmented input.
|
||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
||||
use; default is :class:`torch.optim.Adam`.
|
||||
:param dict optimizer_kwargs: Optimizer constructor keyword args.
|
||||
:param torch.optim.LRScheduler scheduler: Learning
|
||||
rate scheduler.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
"""
|
||||
super().__init__(
|
||||
models=[model],
|
||||
problem=problem,
|
||||
optimizers=[optimizer],
|
||||
optimizers_kwargs=[optimizer_kwargs],
|
||||
extra_features=extra_features,
|
||||
)
|
||||
|
||||
# check consistency
|
||||
check_consistency(scheduler, LRScheduler, subclass=True)
|
||||
check_consistency(scheduler_kwargs, dict)
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
|
||||
# assign variables
|
||||
self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
|
||||
self._loss = loss
|
||||
self._neural_net = self.models[0]
|
||||
|
||||
# inverse problem handling
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self._params = self.problem.unknown_parameters
|
||||
else:
|
||||
self._params = None
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass implementation for the PINN
|
||||
solver.
|
||||
|
||||
:param torch.Tensor x: Input tensor.
|
||||
:return: PINN solution.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.neural_net(x)
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Optimizer configuration for the PINN
|
||||
solver.
|
||||
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
# if the problem is an InverseProblem, add the unknown parameters
|
||||
# to the parameters that the optimizer needs to optimize
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self.optimizers[0].add_param_group(
|
||||
{
|
||||
"params": [
|
||||
self._params[var]
|
||||
for var in self.problem.unknown_variables
|
||||
]
|
||||
}
|
||||
)
|
||||
return self.optimizers, [self.scheduler]
|
||||
|
||||
def _clamp_inverse_problem_params(self):
|
||||
for v in self._params:
|
||||
self._params[v].data.clamp_(
|
||||
self.problem.unknown_parameter_domain.range_[v][0],
|
||||
self.problem.unknown_parameter_domain.range_[v][1],
|
||||
)
|
||||
|
||||
def _loss_data(self, input, output):
|
||||
return self.loss(self.forward(input), output)
|
||||
|
||||
def _loss_phys(self, samples, equation):
|
||||
try:
|
||||
residual = equation.residual(samples, self.forward(samples))
|
||||
except (
|
||||
TypeError
|
||||
): # this occurs when the function has three inputs, i.e. inverse problem
|
||||
residual = equation.residual(
|
||||
samples, self.forward(samples), self._params
|
||||
)
|
||||
return self.loss(
|
||||
torch.zeros_like(residual, requires_grad=True), residual
|
||||
)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""
|
||||
PINN solver training step.
|
||||
|
||||
:param batch: The batch element in the dataloader.
|
||||
:type batch: tuple
|
||||
:param batch_idx: The batch index.
|
||||
:type batch_idx: int
|
||||
:return: The sum of the loss functions.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
dataloader = self.trainer.train_dataloader
|
||||
condition_losses = []
|
||||
|
||||
condition_idx = batch["condition"]
|
||||
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
condition_name = dataloader.condition_names[condition_id]
|
||||
else:
|
||||
condition_name = dataloader.loaders.condition_names[
|
||||
condition_id
|
||||
]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = batch["pts"]
|
||||
|
||||
if len(batch) == 2:
|
||||
samples = pts[condition_idx == condition_id]
|
||||
loss = self._loss_phys(samples, condition.equation)
|
||||
elif len(batch) == 3:
|
||||
samples = pts[condition_idx == condition_id]
|
||||
ground_truth = batch["output"][condition_idx == condition_id]
|
||||
loss = self._loss_data(samples, ground_truth)
|
||||
else:
|
||||
raise ValueError("Batch size not supported")
|
||||
|
||||
# TODO for users this us hard to remember when creating a new solver, to fix in a smarter way
|
||||
loss = loss.as_subclass(torch.Tensor)
|
||||
|
||||
# # add condition losses and accumulate logging for each epoch
|
||||
condition_losses.append(loss * condition.data_weight)
|
||||
self.log(
|
||||
condition_name + "_loss",
|
||||
float(loss),
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
)
|
||||
|
||||
# clamp unknown parameters of the InverseProblem to their domain ranges (if needed)
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self._clamp_inverse_problem_params()
|
||||
|
||||
# TODO Fix the bug, tot_loss is a label tensor without labels
|
||||
# we need to pass it as a torch tensor to make everything work
|
||||
total_loss = sum(condition_losses)
|
||||
self.log(
|
||||
"mean_loss",
|
||||
float(total_loss / len(condition_losses)),
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
)
|
||||
|
||||
return total_loss
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""
|
||||
Scheduler for the PINN training.
|
||||
"""
|
||||
return self._scheduler
|
||||
|
||||
@property
|
||||
def neural_net(self):
|
||||
"""
|
||||
Neural network for the PINN training.
|
||||
"""
|
||||
return self._neural_net
|
||||
|
||||
@property
|
||||
def loss(self):
|
||||
"""
|
||||
Loss for the PINN training.
|
||||
"""
|
||||
return self._loss
|
||||
15
pina/solvers/pinns/__init__.py
Normal file
15
pina/solvers/pinns/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
__all__ = [
|
||||
"PINNInterface",
|
||||
"PINN",
|
||||
"GPINN",
|
||||
"CausalPINN",
|
||||
"CompetitivePINN",
|
||||
"SAPINN",
|
||||
]
|
||||
|
||||
from .basepinn import PINNInterface
|
||||
from .pinn import PINN
|
||||
from .gpinn import GPINN
|
||||
from .causalpinn import CausalPINN
|
||||
from .competitive_pinn import CompetitivePINN
|
||||
from .sapinn import SAPINN
|
||||
247
pina/solvers/pinns/basepinn.py
Normal file
247
pina/solvers/pinns/basepinn.py
Normal file
@@ -0,0 +1,247 @@
|
||||
""" Module for PINN """
|
||||
|
||||
import sys
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
|
||||
from ...solvers.solver import SolverInterface
|
||||
from pina.utils import check_consistency
|
||||
from pina.loss import LossInterface
|
||||
from pina.problem import InverseProblem
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
|
||||
class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
Base PINN solver class. This class implements the Solver Interface
|
||||
for Physics Informed Neural Network solvers.
|
||||
|
||||
This class can be used to
|
||||
define PINNs with multiple ``optimizers``, and/or ``models``.
|
||||
By default it takes
|
||||
an :class:`~pina.problem.abstract_problem.AbstractProblem`, so it is up
|
||||
to the user to choose which problem the implemented solver inheriting from
|
||||
this class is suitable for.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
models,
|
||||
problem,
|
||||
optimizers,
|
||||
optimizers_kwargs,
|
||||
extra_features,
|
||||
loss,
|
||||
):
|
||||
"""
|
||||
:param models: Multiple torch neural network models instances.
|
||||
:type models: list(torch.nn.Module)
|
||||
:param problem: A problem definition instance.
|
||||
:type problem: AbstractProblem
|
||||
:param list(torch.optim.Optimizer) optimizer: A list of neural network
|
||||
optimizers to use.
|
||||
:param list(dict) optimizer_kwargs: A list of optimizer constructor
|
||||
keyword args.
|
||||
:param list(torch.nn.Module) extra_features: The additional input
|
||||
features to use as augmented input. If ``None`` no extra features
|
||||
are passed. If it is a list of :class:`torch.nn.Module`,
|
||||
the extra feature list is passed to all models. If it is a list
|
||||
of extra features' lists, each single list of extra feature
|
||||
is passed to a model.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default :class:`torch.nn.MSELoss`.
|
||||
"""
|
||||
super().__init__(
|
||||
models=models,
|
||||
problem=problem,
|
||||
optimizers=optimizers,
|
||||
optimizers_kwargs=optimizers_kwargs,
|
||||
extra_features=extra_features,
|
||||
)
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
|
||||
# assign variables
|
||||
self._loss = loss
|
||||
|
||||
# inverse problem handling
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self._params = self.problem.unknown_parameters
|
||||
self._clamp_params = self._clamp_inverse_problem_params
|
||||
else:
|
||||
self._params = None
|
||||
self._clamp_params = lambda : None
|
||||
|
||||
# variable used internally to store residual losses at each epoch
|
||||
# this variable save the residual at each iteration (not weighted)
|
||||
self.__logged_res_losses = []
|
||||
|
||||
# variable used internally in pina for logging. This variable points to
|
||||
# the current condition during the training step and returns the
|
||||
# condition name. Whenever :meth:`store_log` is called the logged
|
||||
# variable will be stored with name = self.__logged_metric
|
||||
self.__logged_metric = None
|
||||
|
||||
def training_step(self, batch, _):
|
||||
"""
|
||||
The Physics Informed Solver Training Step. This function takes care
|
||||
of the physics informed training step, and it must not be override
|
||||
if not intentionally. It handles the batching mechanism, the workload
|
||||
division for the various conditions, the inverse problem clamping,
|
||||
and loggers.
|
||||
|
||||
:param tuple batch: The batch element in the dataloader.
|
||||
:param int batch_idx: The batch index.
|
||||
:return: The sum of the loss functions.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
condition_losses = []
|
||||
condition_idx = batch["condition"]
|
||||
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
|
||||
|
||||
condition_name = self._dataloader.condition_names[condition_id]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = batch["pts"]
|
||||
# condition name is logged (if logs enabled)
|
||||
self.__logged_metric = condition_name
|
||||
|
||||
if len(batch) == 2:
|
||||
samples = pts[condition_idx == condition_id]
|
||||
loss = self.loss_phys(samples, condition.equation)
|
||||
elif len(batch) == 3:
|
||||
samples = pts[condition_idx == condition_id]
|
||||
ground_truth = batch["output"][condition_idx == condition_id]
|
||||
loss = self.loss_data(samples, ground_truth)
|
||||
else:
|
||||
raise ValueError("Batch size not supported")
|
||||
|
||||
# add condition losses for each epoch
|
||||
condition_losses.append(loss * condition.data_weight)
|
||||
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
self._clamp_params()
|
||||
|
||||
# total loss (must be a torch.Tensor)
|
||||
total_loss = sum(condition_losses)
|
||||
return total_loss.as_subclass(torch.Tensor)
|
||||
|
||||
def loss_data(self, input_tensor, output_tensor):
|
||||
"""
|
||||
The data loss for the PINN solver. It computes the loss between
|
||||
the network output against the true solution. This function
|
||||
should not be override if not intentionally.
|
||||
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
:param LabelTensor output_tensor: The true solution to compare the
|
||||
network solution.
|
||||
:return: The residual loss averaged on the input coordinates
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
loss_value = self.loss(self.forward(input_tensor), output_tensor)
|
||||
self.store_log(loss_value=float(loss_value))
|
||||
return self.loss(self.forward(input_tensor), output_tensor)
|
||||
|
||||
@abstractmethod
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the physics informed solver based on given
|
||||
samples and equation. This method must be override by all inherited
|
||||
classes and it is the core to define a new physics informed solver.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation
|
||||
representing the physics.
|
||||
:return: The physics loss calculated based on given
|
||||
samples and equation.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_residual(self, samples, equation):
|
||||
"""
|
||||
Compute the residual for Physics Informed learning. This function
|
||||
returns the :obj:`~pina.equation.equation.Equation` specified in the
|
||||
:obj:`~pina.condition.Condition` evaluated at the ``samples`` points.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation
|
||||
representing the physics.
|
||||
:return: The residual of the neural network solution.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
try:
|
||||
residual = equation.residual(samples, self.forward(samples))
|
||||
except (
|
||||
TypeError
|
||||
): # this occurs when the function has three inputs, i.e. inverse problem
|
||||
residual = equation.residual(
|
||||
samples, self.forward(samples), self._params
|
||||
)
|
||||
return residual
|
||||
|
||||
def store_log(self, loss_value):
|
||||
"""
|
||||
Stores the loss value in the logger. This function should be
|
||||
called for all conditions. It automatically handles the storing
|
||||
conditions names. It must be used
|
||||
anytime a specific variable wants to be stored for a specific condition.
|
||||
A simple example is to use the variable to store the residual.
|
||||
|
||||
:param str name: The name of the loss.
|
||||
:param torch.Tensor loss_value: The value of the loss.
|
||||
"""
|
||||
self.log(
|
||||
self.__logged_metric+'_loss',
|
||||
loss_value,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_epoch=True,
|
||||
on_step=False,
|
||||
)
|
||||
self.__logged_res_losses.append(loss_value)
|
||||
|
||||
def on_train_epoch_end(self):
|
||||
"""
|
||||
At the end of each epoch we free the stored losses. This function
|
||||
should not be override if not intentionally.
|
||||
"""
|
||||
if self.__logged_res_losses:
|
||||
# storing mean loss
|
||||
self.__logged_metric = 'mean'
|
||||
self.store_log(
|
||||
sum(self.__logged_res_losses)/len(self.__logged_res_losses)
|
||||
)
|
||||
# free the logged losses
|
||||
self.__logged_res_losses = []
|
||||
return super().on_train_epoch_end()
|
||||
|
||||
def _clamp_inverse_problem_params(self):
|
||||
"""
|
||||
Clamps the parameters of the inverse problem
|
||||
solver to the specified ranges.
|
||||
"""
|
||||
for v in self._params:
|
||||
self._params[v].data.clamp_(
|
||||
self.problem.unknown_parameter_domain.range_[v][0],
|
||||
self.problem.unknown_parameter_domain.range_[v][1],
|
||||
)
|
||||
|
||||
@property
|
||||
def loss(self):
|
||||
"""
|
||||
Loss used for training.
|
||||
"""
|
||||
return self._loss
|
||||
|
||||
@property
|
||||
def current_condition_name(self):
|
||||
"""
|
||||
Returns the condition name. This function can be used inside the
|
||||
:meth:`loss_phys` to extract the condition at which the loss is
|
||||
computed.
|
||||
"""
|
||||
return self.__logged_metric
|
||||
221
pina/solvers/pinns/causalpinn.py
Normal file
221
pina/solvers/pinns/causalpinn.py
Normal file
@@ -0,0 +1,221 @@
|
||||
""" Module for CausalPINN """
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
|
||||
from .pinn import PINN
|
||||
from pina.problem import TimeDependentProblem
|
||||
from pina.utils import check_consistency
|
||||
|
||||
|
||||
class CausalPINN(PINN):
|
||||
r"""
|
||||
Causal Physics Informed Neural Network (PINN) solver class.
|
||||
This class implements Causal Physics Informed Neural
|
||||
Network solvers, using a user specified ``model`` to solve a specific
|
||||
``problem``. It can be used for solving both forward and inverse problems.
|
||||
|
||||
The Causal Physics Informed Network aims to find
|
||||
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
|
||||
of the differential problem:
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{cases}
|
||||
\mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
|
||||
\mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
|
||||
\mathbf{x}\in\partial\Omega
|
||||
\end{cases}
|
||||
|
||||
minimizing the loss function
|
||||
|
||||
.. math::
|
||||
\mathcal{L}_{\rm{problem}} = \frac{1}{N_t}\sum_{i=1}^{N_t}
|
||||
\omega_{i}\mathcal{L}_r(t_i),
|
||||
|
||||
where:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}_r(t) = \frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i, t)) +
|
||||
\frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i, t))
|
||||
|
||||
and,
|
||||
|
||||
.. math::
|
||||
\omega_i = \exp\left(\epsilon \sum_{k=1}^{i-1}\mathcal{L}_r(t_k)\right).
|
||||
|
||||
:math:`\epsilon` is an hyperparameter, default set to :math:`100`, while
|
||||
:math:`\mathcal{L}` is a specific loss function,
|
||||
default Mean Square Error:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}(v) = \| v \|^2_2.
|
||||
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Wang, Sifan, Shyam Sankaran, and Paris
|
||||
Perdikaris. "Respecting causality for training physics-informed
|
||||
neural networks." Computer Methods in Applied Mechanics
|
||||
and Engineering 421 (2024): 116813.
|
||||
DOI `10.1016 <https://doi.org/10.1016/j.cma.2024.116813>`_.
|
||||
|
||||
.. note::
|
||||
This class can only work for problems inheriting
|
||||
from at least
|
||||
:class:`~pina.problem.timedep_problem.TimeDependentProblem` class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
extra_features=None,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer=torch.optim.Adam,
|
||||
optimizer_kwargs={"lr": 0.001},
|
||||
scheduler=ConstantLR,
|
||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||
eps=100,
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default :class:`torch.nn.MSELoss`.
|
||||
:param torch.nn.Module extra_features: The additional input
|
||||
features to use as augmented input.
|
||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
||||
use; default is :class:`torch.optim.Adam`.
|
||||
:param dict optimizer_kwargs: Optimizer constructor keyword args.
|
||||
:param torch.optim.LRScheduler scheduler: Learning
|
||||
rate scheduler.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
:param int | float eps: The exponential decay parameter. Note that this
|
||||
value is kept fixed during the training, but can be changed by means
|
||||
of a callback, e.g. for annealing.
|
||||
"""
|
||||
super().__init__(
|
||||
problem=problem,
|
||||
model=model,
|
||||
extra_features=extra_features,
|
||||
loss=loss,
|
||||
optimizer=optimizer,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
scheduler=scheduler,
|
||||
scheduler_kwargs=scheduler_kwargs,
|
||||
)
|
||||
|
||||
# checking consistency
|
||||
check_consistency(eps, (int,float))
|
||||
self._eps = eps
|
||||
if not isinstance(self.problem, TimeDependentProblem):
|
||||
raise ValueError('Casual PINN works only for problems'
|
||||
'inheritig from TimeDependentProblem.')
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the Causal PINN solver based on given
|
||||
samples and equation.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation
|
||||
representing the physics.
|
||||
:return: The physics loss calculated based on given
|
||||
samples and equation.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# split sequentially ordered time tensors into chunks
|
||||
chunks, labels = self._split_tensor_into_chunks(samples)
|
||||
# compute residuals - this correspond to ordered loss functions
|
||||
# values for each time step. We apply `flatten` such that after
|
||||
# concataning the residuals we obtain a tensor of shape #chunks
|
||||
time_loss = []
|
||||
for chunk in chunks:
|
||||
chunk.labels = labels
|
||||
# classical PINN loss
|
||||
residual = self.compute_residual(samples=chunk, equation=equation)
|
||||
loss_val = self.loss(
|
||||
torch.zeros_like(residual, requires_grad=True), residual
|
||||
)
|
||||
time_loss.append(loss_val)
|
||||
# store results
|
||||
self.store_log(loss_value=float(sum(time_loss)/len(time_loss)))
|
||||
# concatenate residuals
|
||||
time_loss = torch.stack(time_loss)
|
||||
# compute weights (without the gradient storing)
|
||||
with torch.no_grad():
|
||||
weights = self._compute_weights(time_loss)
|
||||
return (weights * time_loss).mean()
|
||||
|
||||
@property
|
||||
def eps(self):
|
||||
"""
|
||||
The exponential decay parameter.
|
||||
"""
|
||||
return self._eps
|
||||
|
||||
@eps.setter
|
||||
def eps(self, value):
|
||||
"""
|
||||
Setter method for the eps parameter.
|
||||
|
||||
:param float value: The exponential decay parameter.
|
||||
"""
|
||||
check_consistency(value, float)
|
||||
self._eps = value
|
||||
|
||||
def _sort_label_tensor(self, tensor):
|
||||
"""
|
||||
Sorts the label tensor based on time variables.
|
||||
|
||||
:param LabelTensor tensor: The label tensor to be sorted.
|
||||
:return: The sorted label tensor based on time variables.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# labels input tensors
|
||||
labels = tensor.labels
|
||||
# extract time tensor
|
||||
time_tensor = tensor.extract(self.problem.temporal_domain.variables)
|
||||
# sort the time tensors (this is very bad for GPU)
|
||||
_, idx = torch.sort(time_tensor.tensor.flatten())
|
||||
tensor = tensor[idx]
|
||||
tensor.labels = labels
|
||||
return tensor
|
||||
|
||||
def _split_tensor_into_chunks(self, tensor):
|
||||
"""
|
||||
Splits the label tensor into chunks based on time.
|
||||
|
||||
:param LabelTensor tensor: The label tensor to be split.
|
||||
:return: Tuple containing the chunks and the original labels.
|
||||
:rtype: Tuple[List[LabelTensor], List]
|
||||
"""
|
||||
# labels input tensors
|
||||
labels = tensor.labels
|
||||
# labels input tensors
|
||||
tensor = self._sort_label_tensor(tensor)
|
||||
# extract time tensor
|
||||
time_tensor = tensor.extract(self.problem.temporal_domain.variables)
|
||||
# count unique tensors in time
|
||||
_, idx_split = time_tensor.unique(return_counts=True)
|
||||
# splitting
|
||||
chunks = torch.split(tensor, tuple(idx_split))
|
||||
return chunks, labels # return chunks
|
||||
|
||||
def _compute_weights(self, loss):
|
||||
"""
|
||||
Computes the weights for the physics loss based on the cumulative loss.
|
||||
|
||||
:param LabelTensor loss: The physics loss values.
|
||||
:return: The computed weights for the physics loss.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# compute comulative loss and multiply by epsilos
|
||||
cumulative_loss = self._eps * torch.cumsum(loss, dim=0)
|
||||
# return the exponential of the weghited negative cumulative sum
|
||||
return torch.exp(-cumulative_loss)
|
||||
360
pina/solvers/pinns/competitive_pinn.py
Normal file
360
pina/solvers/pinns/competitive_pinn.py
Normal file
@@ -0,0 +1,360 @@
|
||||
""" Module for CompetitivePINN """
|
||||
|
||||
import torch
|
||||
import copy
|
||||
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||
except ImportError:
|
||||
from torch.optim.lr_scheduler import (
|
||||
_LRScheduler as LRScheduler,
|
||||
) # torch < 2.0
|
||||
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
|
||||
from .basepinn import PINNInterface
|
||||
from pina.utils import check_consistency
|
||||
from pina.problem import InverseProblem
|
||||
|
||||
|
||||
class CompetitivePINN(PINNInterface):
|
||||
r"""
|
||||
Competitive Physics Informed Neural Network (PINN) solver class.
|
||||
This class implements Competitive Physics Informed Neural
|
||||
Network solvers, using a user specified ``model`` to solve a specific
|
||||
``problem``. It can be used for solving both forward and inverse problems.
|
||||
|
||||
The Competitive Physics Informed Network aims to find
|
||||
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
|
||||
of the differential problem:
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{cases}
|
||||
\mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
|
||||
\mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
|
||||
\mathbf{x}\in\partial\Omega
|
||||
\end{cases}
|
||||
|
||||
with a minimization (on ``model`` parameters) maximation (
|
||||
on ``discriminator`` parameters) of the loss function
|
||||
|
||||
.. math::
|
||||
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(D(\mathbf{x}_i)\mathcal{A}[\mathbf{u}](\mathbf{x}_i))+
|
||||
\frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(D(\mathbf{x}_i)\mathcal{B}[\mathbf{u}](\mathbf{x}_i))
|
||||
|
||||
where :math:`D` is the discriminator network, which tries to find the points
|
||||
where the network performs worst, and :math:`\mathcal{L}` is a specific loss
|
||||
function, default Mean Square Error:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}(v) = \| v \|^2_2.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Zeng, Qi, et al.
|
||||
"Competitive physics informed networks." International Conference on
|
||||
Learning Representations, ICLR 2022
|
||||
`OpenReview Preprint <https://openreview.net/forum?id=z9SIj-IM7tn>`_.
|
||||
|
||||
.. warning::
|
||||
This solver does not currently support the possibility to pass
|
||||
``extra_feature``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
discriminator=None,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer_model=torch.optim.Adam,
|
||||
optimizer_model_kwargs={"lr": 0.001},
|
||||
optimizer_discriminator=torch.optim.Adam,
|
||||
optimizer_discriminator_kwargs={"lr": 0.001},
|
||||
scheduler_model=ConstantLR,
|
||||
scheduler_model_kwargs={"factor": 1, "total_iters": 0},
|
||||
scheduler_discriminator=ConstantLR,
|
||||
scheduler_discriminator_kwargs={"factor": 1, "total_iters": 0},
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use
|
||||
for the model.
|
||||
:param torch.nn.Module discriminator: The neural network model to use
|
||||
for the discriminator. If ``None``, the discriminator network will
|
||||
have the same architecture as the model network.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default :class:`torch.nn.MSELoss`.
|
||||
:param torch.optim.Optimizer optimizer_model: The neural
|
||||
network optimizer to use for the model network
|
||||
, default is `torch.optim.Adam`.
|
||||
:param dict optimizer_model_kwargs: Optimizer constructor keyword
|
||||
args. for the model.
|
||||
:param torch.optim.Optimizer optimizer_discriminator: The neural
|
||||
network optimizer to use for the discriminator network
|
||||
, default is `torch.optim.Adam`.
|
||||
:param dict optimizer_discriminator_kwargs: Optimizer constructor
|
||||
keyword args. for the discriminator.
|
||||
:param torch.optim.LRScheduler scheduler_model: Learning
|
||||
rate scheduler for the model.
|
||||
:param dict scheduler_model_kwargs: LR scheduler constructor
|
||||
keyword args.
|
||||
:param torch.optim.LRScheduler scheduler_discriminator: Learning
|
||||
rate scheduler for the discriminator.
|
||||
"""
|
||||
if discriminator is None:
|
||||
discriminator = copy.deepcopy(model)
|
||||
|
||||
super().__init__(
|
||||
models=[model, discriminator],
|
||||
problem=problem,
|
||||
optimizers=[optimizer_model, optimizer_discriminator],
|
||||
optimizers_kwargs=[
|
||||
optimizer_model_kwargs,
|
||||
optimizer_discriminator_kwargs,
|
||||
],
|
||||
extra_features=None, # CompetitivePINN doesn't take extra features
|
||||
loss=loss
|
||||
)
|
||||
|
||||
# set automatic optimization for GANs
|
||||
self.automatic_optimization = False
|
||||
|
||||
# check consistency
|
||||
check_consistency(scheduler_model, LRScheduler, subclass=True)
|
||||
check_consistency(scheduler_model_kwargs, dict)
|
||||
check_consistency(scheduler_discriminator, LRScheduler, subclass=True)
|
||||
check_consistency(scheduler_discriminator_kwargs, dict)
|
||||
|
||||
# assign schedulers
|
||||
self._schedulers = [
|
||||
scheduler_model(
|
||||
self.optimizers[0], **scheduler_model_kwargs
|
||||
),
|
||||
scheduler_discriminator(
|
||||
self.optimizers[1], **scheduler_discriminator_kwargs
|
||||
),
|
||||
]
|
||||
|
||||
self._model = self.models[0]
|
||||
self._discriminator = self.models[1]
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Forward pass implementation for the PINN solver. It returns the function
|
||||
evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
|
||||
:math:`\mathbf{x}`.
|
||||
|
||||
:param LabelTensor x: Input tensor for the PINN solver. It expects
|
||||
a tensor :math:`N \times D`, where :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the problem,
|
||||
:return: PINN solution evaluated at contro points.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self.neural_net(x)
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the Competitive PINN solver based on given
|
||||
samples and equation.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation
|
||||
representing the physics.
|
||||
:return: The physics loss calculated based on given
|
||||
samples and equation.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# train one step of the model
|
||||
with torch.no_grad():
|
||||
discriminator_bets = self.discriminator(samples)
|
||||
loss_val = self._train_model(samples, equation, discriminator_bets)
|
||||
self.store_log(loss_value=float(loss_val))
|
||||
# detaching samples from the computational graph to erase it and setting
|
||||
# the gradient to true to create a new computational graph.
|
||||
# In alternative set `retain_graph=True`.
|
||||
samples = samples.detach()
|
||||
samples.requires_grad = True
|
||||
# train one step of discriminator
|
||||
discriminator_bets = self.discriminator(samples)
|
||||
self._train_discriminator(samples, equation, discriminator_bets)
|
||||
return loss_val
|
||||
|
||||
def loss_data(self, input_tensor, output_tensor):
|
||||
"""
|
||||
The data loss for the PINN solver. It computes the loss between the
|
||||
network output against the true solution.
|
||||
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
:param LabelTensor output_tensor: The true solution to compare the
|
||||
network solution.
|
||||
:return: The computed data loss.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
self.optimizer_model.zero_grad()
|
||||
loss_val = super().loss_data(
|
||||
input_tensor, output_tensor).as_subclass(torch.Tensor)
|
||||
loss_val.backward()
|
||||
self.optimizer_model.step()
|
||||
return loss_val
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Optimizer configuration for the Competitive PINN solver.
|
||||
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
# if the problem is an InverseProblem, add the unknown parameters
|
||||
# to the parameters that the optimizer needs to optimize
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self.optimizer_model.add_param_group(
|
||||
{
|
||||
"params": [
|
||||
self._params[var]
|
||||
for var in self.problem.unknown_variables
|
||||
]
|
||||
}
|
||||
)
|
||||
return self.optimizers, self._schedulers
|
||||
|
||||
def on_train_batch_end(self,outputs, batch, batch_idx):
|
||||
"""
|
||||
This method is called at the end of each training batch, and ovverides
|
||||
the PytorchLightining implementation for logging the checkpoints.
|
||||
|
||||
:param torch.Tensor outputs: The output from the model for the
|
||||
current batch.
|
||||
:param tuple batch: The current batch of data.
|
||||
:param int batch_idx: The index of the current batch.
|
||||
:return: Whatever is returned by the parent
|
||||
method ``on_train_batch_end``.
|
||||
:rtype: Any
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
|
||||
def _train_discriminator(self, samples, equation, discriminator_bets):
|
||||
"""
|
||||
Trains the discriminator network of the Competitive PINN.
|
||||
|
||||
:param LabelTensor samples: Input samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation representing
|
||||
the physics.
|
||||
:param Tensor discriminator_bets: Predictions made by the discriminator
|
||||
network.
|
||||
"""
|
||||
# manual optimization
|
||||
self.optimizer_discriminator.zero_grad()
|
||||
# compute residual, we detach because the weights of the generator
|
||||
# model are fixed
|
||||
residual = self.compute_residual(samples=samples,
|
||||
equation=equation).detach()
|
||||
# compute competitive residual, the minus is because we maximise
|
||||
competitive_residual = residual * discriminator_bets
|
||||
loss_val = - self.loss(
|
||||
torch.zeros_like(competitive_residual, requires_grad=True),
|
||||
competitive_residual
|
||||
).as_subclass(torch.Tensor)
|
||||
# backprop
|
||||
self.manual_backward(loss_val)
|
||||
self.optimizer_discriminator.step()
|
||||
return
|
||||
|
||||
def _train_model(self, samples, equation, discriminator_bets):
|
||||
"""
|
||||
Trains the model network of the Competitive PINN.
|
||||
|
||||
:param LabelTensor samples: Input samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation representing
|
||||
the physics.
|
||||
:param Tensor discriminator_bets: Predictions made by the discriminator.
|
||||
network.
|
||||
:return: The computed data loss.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# manual optimization
|
||||
self.optimizer_model.zero_grad()
|
||||
# compute residual (detached for discriminator) and log
|
||||
residual = self.compute_residual(samples=samples, equation=equation)
|
||||
# store logging
|
||||
with torch.no_grad():
|
||||
loss_residual = self.loss(
|
||||
torch.zeros_like(residual),
|
||||
residual
|
||||
)
|
||||
# compute competitive residual, discriminator_bets are detached becase
|
||||
# we optimize only the generator model
|
||||
competitive_residual = residual * discriminator_bets.detach()
|
||||
loss_val = self.loss(
|
||||
torch.zeros_like(competitive_residual, requires_grad=True),
|
||||
competitive_residual
|
||||
).as_subclass(torch.Tensor)
|
||||
# backprop
|
||||
self.manual_backward(loss_val)
|
||||
self.optimizer_model.step()
|
||||
return loss_residual
|
||||
|
||||
@property
|
||||
def neural_net(self):
|
||||
"""
|
||||
Returns the neural network model.
|
||||
|
||||
:return: The neural network model.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def discriminator(self):
|
||||
"""
|
||||
Returns the discriminator model (if applicable).
|
||||
|
||||
:return: The discriminator model.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._discriminator
|
||||
|
||||
@property
|
||||
def optimizer_model(self):
|
||||
"""
|
||||
Returns the optimizer associated with the neural network model.
|
||||
|
||||
:return: The optimizer for the neural network model.
|
||||
:rtype: torch.optim.Optimizer
|
||||
"""
|
||||
return self.optimizers[0]
|
||||
|
||||
@property
|
||||
def optimizer_discriminator(self):
|
||||
"""
|
||||
Returns the optimizer associated with the discriminator (if applicable).
|
||||
|
||||
:return: The optimizer for the discriminator.
|
||||
:rtype: torch.optim.Optimizer
|
||||
"""
|
||||
return self.optimizers[1]
|
||||
|
||||
@property
|
||||
def scheduler_model(self):
|
||||
"""
|
||||
Returns the scheduler associated with the neural network model.
|
||||
|
||||
:return: The scheduler for the neural network model.
|
||||
:rtype: torch.optim.lr_scheduler._LRScheduler
|
||||
"""
|
||||
return self._schedulers[0]
|
||||
|
||||
@property
|
||||
def scheduler_discriminator(self):
|
||||
"""
|
||||
Returns the scheduler associated with the discriminator (if applicable).
|
||||
|
||||
:return: The scheduler for the discriminator.
|
||||
:rtype: torch.optim.lr_scheduler._LRScheduler
|
||||
"""
|
||||
return self._schedulers[1]
|
||||
134
pina/solvers/pinns/gpinn.py
Normal file
134
pina/solvers/pinns/gpinn.py
Normal file
@@ -0,0 +1,134 @@
|
||||
""" Module for GPINN """
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
|
||||
from .pinn import PINN
|
||||
from pina.operators import grad
|
||||
from pina.problem import SpatialProblem
|
||||
|
||||
|
||||
class GPINN(PINN):
|
||||
r"""
|
||||
Gradient Physics Informed Neural Network (GPINN) solver class.
|
||||
This class implements Gradient Physics Informed Neural
|
||||
Network solvers, using a user specified ``model`` to solve a specific
|
||||
``problem``. It can be used for solving both forward and inverse problems.
|
||||
|
||||
The Gradient Physics Informed Network aims to find
|
||||
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
|
||||
of the differential problem:
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{cases}
|
||||
\mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
|
||||
\mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
|
||||
\mathbf{x}\in\partial\Omega
|
||||
\end{cases}
|
||||
|
||||
minimizing the loss function
|
||||
|
||||
.. math::
|
||||
\mathcal{L}_{\rm{problem}} =& \frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
|
||||
\frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)) + \\
|
||||
&\frac{1}{N}\sum_{i=1}^N
|
||||
\nabla_{\mathbf{x}}\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
|
||||
\frac{1}{N}\sum_{i=1}^N
|
||||
\nabla_{\mathbf{x}}\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i))
|
||||
|
||||
|
||||
where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}(v) = \| v \|^2_2.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Yu, Jeremy, et al. "Gradient-enhanced
|
||||
physics-informed neural networks for forward and inverse
|
||||
PDE problems." Computer Methods in Applied Mechanics
|
||||
and Engineering 393 (2022): 114823.
|
||||
DOI: `10.1016 <https://doi.org/10.1016/j.cma.2022.114823>`_.
|
||||
|
||||
.. note::
|
||||
This class can only work for problems inheriting
|
||||
from at least :class:`~pina.problem.spatial_problem.SpatialProblem`
|
||||
class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
extra_features=None,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer=torch.optim.Adam,
|
||||
optimizer_kwargs={"lr": 0.001},
|
||||
scheduler=ConstantLR,
|
||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formulation of the problem. It must
|
||||
inherit from at least
|
||||
:class:`~pina.problem.spatial_problem.SpatialProblem` in order to
|
||||
compute the gradient of the loss.
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default :class:`torch.nn.MSELoss`.
|
||||
:param torch.nn.Module extra_features: The additional input
|
||||
features to use as augmented input.
|
||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
||||
use; default is :class:`torch.optim.Adam`.
|
||||
:param dict optimizer_kwargs: Optimizer constructor keyword args.
|
||||
:param torch.optim.LRScheduler scheduler: Learning
|
||||
rate scheduler.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
"""
|
||||
super().__init__(
|
||||
problem=problem,
|
||||
model=model,
|
||||
extra_features=extra_features,
|
||||
loss=loss,
|
||||
optimizer=optimizer,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
scheduler=scheduler,
|
||||
scheduler_kwargs=scheduler_kwargs,
|
||||
)
|
||||
if not isinstance(self.problem, SpatialProblem):
|
||||
raise ValueError('Gradient PINN computes the gradient of the '
|
||||
'PINN loss with respect to the spatial '
|
||||
'coordinates, thus the PINA problem must be '
|
||||
'a SpatialProblem.')
|
||||
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the GPINN solver based on given
|
||||
samples and equation.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation
|
||||
representing the physics.
|
||||
:return: The physics loss calculated based on given
|
||||
samples and equation.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# classical PINN loss
|
||||
residual = self.compute_residual(samples=samples, equation=equation)
|
||||
loss_value = self.loss(
|
||||
torch.zeros_like(residual, requires_grad=True), residual
|
||||
)
|
||||
self.store_log(loss_value=float(loss_value))
|
||||
# gradient PINN loss
|
||||
loss_value = loss_value.reshape(-1, 1)
|
||||
loss_value.labels = ['__LOSS']
|
||||
loss_grad = grad(loss_value, samples, d=self.problem.spatial_variables)
|
||||
g_loss_phys = self.loss(
|
||||
torch.zeros_like(loss_grad, requires_grad=True), loss_grad
|
||||
)
|
||||
return loss_value + g_loss_phys
|
||||
170
pina/solvers/pinns/pinn.py
Normal file
170
pina/solvers/pinns/pinn.py
Normal file
@@ -0,0 +1,170 @@
|
||||
""" Module for Physics Informed Neural Network. """
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||
except ImportError:
|
||||
from torch.optim.lr_scheduler import (
|
||||
_LRScheduler as LRScheduler,
|
||||
) # torch < 2.0
|
||||
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
|
||||
from .basepinn import PINNInterface
|
||||
from pina.utils import check_consistency
|
||||
from pina.problem import InverseProblem
|
||||
|
||||
|
||||
class PINN(PINNInterface):
|
||||
r"""
|
||||
Physics Informed Neural Network (PINN) solver class.
|
||||
This class implements Physics Informed Neural
|
||||
Network solvers, using a user specified ``model`` to solve a specific
|
||||
``problem``. It can be used for solving both forward and inverse problems.
|
||||
|
||||
The Physics Informed Network aims to find
|
||||
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
|
||||
of the differential problem:
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{cases}
|
||||
\mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
|
||||
\mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
|
||||
\mathbf{x}\in\partial\Omega
|
||||
\end{cases}
|
||||
|
||||
minimizing the loss function
|
||||
|
||||
.. math::
|
||||
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
|
||||
\frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i))
|
||||
|
||||
where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}(v) = \| v \|^2_2.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L.,
|
||||
Perdikaris, P., Wang, S., & Yang, L. (2021).
|
||||
Physics-informed machine learning. Nature Reviews Physics, 3, 422-440.
|
||||
DOI: `10.1038 <https://doi.org/10.1038/s42254-021-00314-5>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
extra_features=None,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer=torch.optim.Adam,
|
||||
optimizer_kwargs={"lr": 0.001},
|
||||
scheduler=ConstantLR,
|
||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default :class:`torch.nn.MSELoss`.
|
||||
:param torch.nn.Module extra_features: The additional input
|
||||
features to use as augmented input.
|
||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
||||
use; default is :class:`torch.optim.Adam`.
|
||||
:param dict optimizer_kwargs: Optimizer constructor keyword args.
|
||||
:param torch.optim.LRScheduler scheduler: Learning
|
||||
rate scheduler.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
"""
|
||||
super().__init__(
|
||||
models=[model],
|
||||
problem=problem,
|
||||
optimizers=[optimizer],
|
||||
optimizers_kwargs=[optimizer_kwargs],
|
||||
extra_features=extra_features,
|
||||
loss=loss
|
||||
)
|
||||
|
||||
# check consistency
|
||||
check_consistency(scheduler, LRScheduler, subclass=True)
|
||||
check_consistency(scheduler_kwargs, dict)
|
||||
|
||||
# assign variables
|
||||
self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
|
||||
self._neural_net = self.models[0]
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Forward pass implementation for the PINN solver. It returns the function
|
||||
evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
|
||||
:math:`\mathbf{x}`.
|
||||
|
||||
:param LabelTensor x: Input tensor for the PINN solver. It expects
|
||||
a tensor :math:`N \times D`, where :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the problem,
|
||||
:return: PINN solution evaluated at contro points.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self.neural_net(x)
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the PINN solver based on given
|
||||
samples and equation.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation
|
||||
representing the physics.
|
||||
:return: The physics loss calculated based on given
|
||||
samples and equation.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
residual = self.compute_residual(samples=samples, equation=equation)
|
||||
loss_value = self.loss(
|
||||
torch.zeros_like(residual, requires_grad=True), residual
|
||||
)
|
||||
self.store_log(loss_value=float(loss_value))
|
||||
return loss_value
|
||||
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Optimizer configuration for the PINN
|
||||
solver.
|
||||
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
# if the problem is an InverseProblem, add the unknown parameters
|
||||
# to the parameters that the optimizer needs to optimize
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self.optimizers[0].add_param_group(
|
||||
{
|
||||
"params": [
|
||||
self._params[var]
|
||||
for var in self.problem.unknown_variables
|
||||
]
|
||||
}
|
||||
)
|
||||
return self.optimizers, [self.scheduler]
|
||||
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""
|
||||
Scheduler for the PINN training.
|
||||
"""
|
||||
return self._scheduler
|
||||
|
||||
|
||||
@property
|
||||
def neural_net(self):
|
||||
"""
|
||||
Neural network for the PINN training.
|
||||
"""
|
||||
return self._neural_net
|
||||
494
pina/solvers/pinns/sapinn.py
Normal file
494
pina/solvers/pinns/sapinn.py
Normal file
@@ -0,0 +1,494 @@
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||
except ImportError:
|
||||
from torch.optim.lr_scheduler import (
|
||||
_LRScheduler as LRScheduler,
|
||||
) # torch < 2.0
|
||||
|
||||
from .basepinn import PINNInterface
|
||||
from pina.utils import check_consistency
|
||||
from pina.problem import InverseProblem
|
||||
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
|
||||
class Weights(torch.nn.Module):
|
||||
"""
|
||||
This class aims to implements the mask model for
|
||||
self adaptive weights of the Self-Adaptive
|
||||
PINN solver.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
"""
|
||||
:param torch.nn.Module func: the mask module of SAPINN
|
||||
"""
|
||||
super().__init__()
|
||||
check_consistency(func, torch.nn.Module)
|
||||
self.sa_weights = torch.nn.Parameter(
|
||||
torch.Tensor()
|
||||
)
|
||||
self.func = func
|
||||
|
||||
def forward(self):
|
||||
"""
|
||||
Forward pass implementation for the mask module.
|
||||
It returns the function on the weights
|
||||
evaluation.
|
||||
|
||||
:return: evaluation of self adaptive weights through the mask.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.func(self.sa_weights)
|
||||
|
||||
class SAPINN(PINNInterface):
|
||||
r"""
|
||||
Self Adaptive Physics Informed Neural Network (SAPINN) solver class.
|
||||
This class implements Self-Adaptive Physics Informed Neural
|
||||
Network solvers, using a user specified ``model`` to solve a specific
|
||||
``problem``. It can be used for solving both forward and inverse problems.
|
||||
|
||||
The Self Adapive Physics Informed Neural Network aims to find
|
||||
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
|
||||
of the differential problem:
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{cases}
|
||||
\mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
|
||||
\mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
|
||||
\mathbf{x}\in\partial\Omega
|
||||
\end{cases}
|
||||
|
||||
integrating the pointwise loss evaluation through a mask :math:`m` and
|
||||
self adaptive weights that permit to focus the loss function on
|
||||
specific training samples.
|
||||
The loss function to solve the problem is
|
||||
|
||||
.. math::
|
||||
|
||||
\mathcal{L}_{\rm{problem}} = \frac{1}{N} \sum_{i=1}^{N_\Omega} m
|
||||
\left( \lambda_{\Omega}^{i} \right) \mathcal{L} \left( \mathcal{A}
|
||||
[\mathbf{u}](\mathbf{x}) \right) + \frac{1}{N}
|
||||
\sum_{i=1}^{N_{\partial\Omega}}
|
||||
m \left( \lambda_{\partial\Omega}^{i} \right) \mathcal{L}
|
||||
\left( \mathcal{B}[\mathbf{u}](\mathbf{x})
|
||||
\right),
|
||||
|
||||
|
||||
denoting the self adaptive weights as
|
||||
:math:`\lambda_{\Omega}^1, \dots, \lambda_{\Omega}^{N_\Omega}` and
|
||||
:math:`\lambda_{\partial \Omega}^1, \dots,
|
||||
\lambda_{\Omega}^{N_\partial \Omega}`
|
||||
for :math:`\Omega` and :math:`\partial \Omega`, respectively.
|
||||
|
||||
Self Adaptive Physics Informed Neural Network identifies the solution
|
||||
and appropriate self adaptive weights by solving the following problem
|
||||
|
||||
.. math::
|
||||
|
||||
\min_{w} \max_{\lambda_{\Omega}^k, \lambda_{\partial \Omega}^s}
|
||||
\mathcal{L} ,
|
||||
|
||||
where :math:`w` denotes the network parameters, and
|
||||
:math:`\mathcal{L}` is a specific loss
|
||||
function, default Mean Square Error:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}(v) = \| v \|^2_2.
|
||||
|
||||
.. seealso::
|
||||
**Original reference**: McClenny, Levi D., and Ulisses M. Braga-Neto.
|
||||
"Self-adaptive physics-informed neural networks."
|
||||
Journal of Computational Physics 474 (2023): 111722.
|
||||
DOI: `10.1016/
|
||||
j.jcp.2022.111722 <https://doi.org/10.1016/j.jcp.2022.111722>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
weights_function=torch.nn.Sigmoid(),
|
||||
extra_features=None,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer_model=torch.optim.Adam,
|
||||
optimizer_model_kwargs={"lr" : 0.001},
|
||||
optimizer_weights=torch.optim.Adam,
|
||||
optimizer_weights_kwargs={"lr" : 0.001},
|
||||
scheduler_model=ConstantLR,
|
||||
scheduler_model_kwargs={"factor" : 1, "total_iters" : 0},
|
||||
scheduler_weights=ConstantLR,
|
||||
scheduler_weights_kwargs={"factor" : 1, "total_iters" : 0}
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use
|
||||
for the model.
|
||||
:param torch.nn.Module weights_function: The neural network model
|
||||
related to the mask of SAPINN.
|
||||
default :obj:`~torch.nn.Sigmoid`.
|
||||
:param list(torch.nn.Module) extra_features: The additional input
|
||||
features to use as augmented input. If ``None`` no extra features
|
||||
are passed. If it is a list of :class:`torch.nn.Module`,
|
||||
the extra feature list is passed to all models. If it is a list
|
||||
of extra features' lists, each single list of extra feature
|
||||
is passed to a model.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default :class:`torch.nn.MSELoss`.
|
||||
:param torch.optim.Optimizer optimizer_model: The neural
|
||||
network optimizer to use for the model network
|
||||
, default is `torch.optim.Adam`.
|
||||
:param dict optimizer_model_kwargs: Optimizer constructor keyword
|
||||
args. for the model.
|
||||
:param torch.optim.Optimizer optimizer_weights: The neural
|
||||
network optimizer to use for mask model model,
|
||||
default is `torch.optim.Adam`.
|
||||
:param dict optimizer_weights_kwargs: Optimizer constructor
|
||||
keyword args. for the mask module.
|
||||
:param torch.optim.LRScheduler scheduler_model: Learning
|
||||
rate scheduler for the model.
|
||||
:param dict scheduler_model_kwargs: LR scheduler constructor
|
||||
keyword args.
|
||||
:param torch.optim.LRScheduler scheduler_weights: Learning
|
||||
rate scheduler for the mask model.
|
||||
:param dict scheduler_model_kwargs: LR scheduler constructor
|
||||
keyword args.
|
||||
"""
|
||||
|
||||
# check consistency weitghs_function
|
||||
check_consistency(weights_function, torch.nn.Module)
|
||||
|
||||
# create models for weights
|
||||
weights_dict = {}
|
||||
for condition_name in problem.conditions:
|
||||
weights_dict[condition_name] = Weights(weights_function)
|
||||
weights_dict = torch.nn.ModuleDict(weights_dict)
|
||||
|
||||
|
||||
super().__init__(
|
||||
models=[model, weights_dict],
|
||||
problem=problem,
|
||||
optimizers=[optimizer_model, optimizer_weights],
|
||||
optimizers_kwargs=[
|
||||
optimizer_model_kwargs,
|
||||
optimizer_weights_kwargs
|
||||
],
|
||||
extra_features=extra_features,
|
||||
loss=loss
|
||||
)
|
||||
|
||||
# set automatic optimization
|
||||
self.automatic_optimization = False
|
||||
|
||||
# check consistency
|
||||
check_consistency(scheduler_model, LRScheduler, subclass=True)
|
||||
check_consistency(scheduler_model_kwargs, dict)
|
||||
check_consistency(scheduler_weights, LRScheduler, subclass=True)
|
||||
check_consistency(scheduler_weights_kwargs, dict)
|
||||
|
||||
# assign schedulers
|
||||
self._schedulers = [
|
||||
scheduler_model(
|
||||
self.optimizers[0], **scheduler_model_kwargs
|
||||
),
|
||||
scheduler_weights(
|
||||
self.optimizers[1], **scheduler_weights_kwargs
|
||||
),
|
||||
]
|
||||
|
||||
self._model = self.models[0]
|
||||
self._weights = self.models[1]
|
||||
|
||||
self._vectorial_loss = deepcopy(loss)
|
||||
self._vectorial_loss.reduction = "none"
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass implementation for the PINN
|
||||
solver. It returns the function
|
||||
evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
|
||||
:math:`\mathbf{x}`.
|
||||
|
||||
:param LabelTensor x: Input tensor for the SAPINN solver. It expects
|
||||
a tensor :math:`N \\times D`, where :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the problem,
|
||||
:return: PINN solution.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self.neural_net(x)
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the SAPINN solver based on given
|
||||
samples and equation.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation
|
||||
representing the physics.
|
||||
:return: The physics loss calculated based on given
|
||||
samples and equation.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# train weights
|
||||
self.optimizer_weights.zero_grad()
|
||||
weighted_loss, _ = self._loss_phys(samples, equation)
|
||||
loss_value = - weighted_loss.as_subclass(torch.Tensor)
|
||||
self.manual_backward(loss_value)
|
||||
self.optimizer_weights.step()
|
||||
|
||||
# detaching samples from the computational graph to erase it and setting
|
||||
# the gradient to true to create a new computational graph.
|
||||
# In alternative set `retain_graph=True`.
|
||||
samples = samples.detach()
|
||||
samples.requires_grad = True
|
||||
|
||||
# train model
|
||||
self.optimizer_model.zero_grad()
|
||||
weighted_loss, loss = self._loss_phys(samples, equation)
|
||||
loss_value = weighted_loss.as_subclass(torch.Tensor)
|
||||
self.manual_backward(loss_value)
|
||||
self.optimizer_model.step()
|
||||
|
||||
# store loss without weights
|
||||
self.store_log(loss_value=float(loss))
|
||||
return loss_value
|
||||
|
||||
def loss_data(self, input_tensor, output_tensor):
|
||||
"""
|
||||
Computes the data loss for the SAPINN solver based on input and
|
||||
output. It computes the loss between the
|
||||
network output against the true solution.
|
||||
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
:param LabelTensor output_tensor: The true solution to compare the
|
||||
network solution.
|
||||
:return: The computed data loss.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# train weights
|
||||
self.optimizer_weights.zero_grad()
|
||||
weighted_loss, _ = self._loss_data(input_tensor, output_tensor)
|
||||
loss_value = - weighted_loss.as_subclass(torch.Tensor)
|
||||
self.manual_backward(loss_value)
|
||||
self.optimizer_weights.step()
|
||||
|
||||
# detaching samples from the computational graph to erase it and setting
|
||||
# the gradient to true to create a new computational graph.
|
||||
# In alternative set `retain_graph=True`.
|
||||
input_tensor = input_tensor.detach()
|
||||
input_tensor.requires_grad = True
|
||||
|
||||
# train model
|
||||
self.optimizer_model.zero_grad()
|
||||
weighted_loss, loss = self._loss_data(input_tensor, output_tensor)
|
||||
loss_value = weighted_loss.as_subclass(torch.Tensor)
|
||||
self.manual_backward(loss_value)
|
||||
self.optimizer_model.step()
|
||||
|
||||
# store loss without weights
|
||||
self.store_log(loss_value=float(loss))
|
||||
return loss_value
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Optimizer configuration for the SAPINN
|
||||
solver.
|
||||
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
# if the problem is an InverseProblem, add the unknown parameters
|
||||
# to the parameters that the optimizer needs to optimize
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self.optimizers[0].add_param_group(
|
||||
{
|
||||
"params": [
|
||||
self._params[var]
|
||||
for var in self.problem.unknown_variables
|
||||
]
|
||||
}
|
||||
)
|
||||
return self.optimizers, self._schedulers
|
||||
|
||||
def on_train_batch_end(self,outputs, batch, batch_idx):
|
||||
"""
|
||||
This method is called at the end of each training batch, and ovverides
|
||||
the PytorchLightining implementation for logging the checkpoints.
|
||||
|
||||
:param torch.Tensor outputs: The output from the model for the
|
||||
current batch.
|
||||
:param tuple batch: The current batch of data.
|
||||
:param int batch_idx: The index of the current batch.
|
||||
:return: Whatever is returned by the parent
|
||||
method ``on_train_batch_end``.
|
||||
:rtype: Any
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
|
||||
def on_train_start(self):
|
||||
"""
|
||||
This method is called at the start of the training for setting
|
||||
the self adaptive weights as parameters of the mask model.
|
||||
|
||||
:return: Whatever is returned by the parent
|
||||
method ``on_train_start``.
|
||||
:rtype: Any
|
||||
"""
|
||||
device = torch.device(
|
||||
self.trainer._accelerator_connector._accelerator_flag
|
||||
)
|
||||
for condition_name, tensor in self.problem.input_pts.items():
|
||||
self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand(
|
||||
(tensor.shape[0], 1),
|
||||
device = device
|
||||
)
|
||||
return super().on_train_start()
|
||||
|
||||
def on_load_checkpoint(self, checkpoint):
|
||||
"""
|
||||
Overriding the Pytorch Lightning ``on_load_checkpoint`` to handle
|
||||
checkpoints for Self Adaptive Weights. This method should not be
|
||||
overridden if not intentionally.
|
||||
|
||||
:param dict checkpoint: Pytorch Lightning checkpoint dict.
|
||||
"""
|
||||
for condition_name, tensor in self.problem.input_pts.items():
|
||||
self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand(
|
||||
(tensor.shape[0], 1)
|
||||
)
|
||||
return super().on_load_checkpoint(checkpoint)
|
||||
|
||||
def _loss_phys(self, samples, equation):
|
||||
"""
|
||||
Elaboration of the physical loss for the SAPINN solver.
|
||||
|
||||
:param LabelTensor samples: Input samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: the governing equation representing
|
||||
the physics.
|
||||
|
||||
:return: tuple with weighted and not weighted scalar loss
|
||||
:rtype: List[LabelTensor, LabelTensor]
|
||||
"""
|
||||
residual = self.compute_residual(samples, equation)
|
||||
return self._compute_loss(residual)
|
||||
|
||||
def _loss_data(self, input_tensor, output_tensor):
|
||||
"""
|
||||
Elaboration of the loss related to data for the SAPINN solver.
|
||||
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
:param LabelTensor output_tensor: The true solution to compare the
|
||||
network solution.
|
||||
|
||||
:return: tuple with weighted and not weighted scalar loss
|
||||
:rtype: List[LabelTensor, LabelTensor]
|
||||
"""
|
||||
residual = self.forward(input_tensor) - output_tensor
|
||||
return self._compute_loss(residual)
|
||||
|
||||
def _compute_loss(self, residual):
|
||||
"""
|
||||
Elaboration of the pointwise loss through the mask model and the
|
||||
self adaptive weights
|
||||
|
||||
:param LabelTensor residual: the matrix of residuals that have to
|
||||
be weighted
|
||||
|
||||
:return: tuple with weighted and not weighted loss
|
||||
:rtype List[LabelTensor, LabelTensor]
|
||||
"""
|
||||
weights = self.weights_dict.torchmodel[
|
||||
self.current_condition_name].forward()
|
||||
loss_value = self._vectorial_loss(torch.zeros_like(
|
||||
residual, requires_grad=True), residual)
|
||||
return (
|
||||
self._vect_to_scalar(weights * loss_value),
|
||||
self._vect_to_scalar(loss_value)
|
||||
)
|
||||
|
||||
def _vect_to_scalar(self, loss_value):
|
||||
"""
|
||||
Elaboration of the pointwise loss through the mask model and the
|
||||
self adaptive weights
|
||||
|
||||
:param LabelTensor loss_value: the matrix of pointwise loss
|
||||
|
||||
:return: the scalar loss
|
||||
:rtype LabelTensor
|
||||
"""
|
||||
if self.loss.reduction == "mean":
|
||||
ret = torch.mean(loss_value)
|
||||
elif self.loss.reduction == "sum":
|
||||
ret = torch.sum(loss_value)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid reduction, got {self.loss.reduction} "
|
||||
"but expected mean or sum.")
|
||||
return ret
|
||||
|
||||
|
||||
@property
|
||||
def neural_net(self):
|
||||
"""
|
||||
Returns the neural network model.
|
||||
|
||||
:return: The neural network model.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self.models[0]
|
||||
|
||||
@property
|
||||
def weights_dict(self):
|
||||
"""
|
||||
Return the mask models associate to the application of
|
||||
the mask to the self adaptive weights for each loss that
|
||||
compones the global loss of the problem.
|
||||
|
||||
:return: The ModuleDict for mask models.
|
||||
:rtype: torch.nn.ModuleDict
|
||||
"""
|
||||
return self.models[1]
|
||||
|
||||
@property
|
||||
def scheduler_model(self):
|
||||
"""
|
||||
Returns the scheduler associated with the neural network model.
|
||||
|
||||
:return: The scheduler for the neural network model.
|
||||
:rtype: torch.optim.lr_scheduler._LRScheduler
|
||||
"""
|
||||
return self._scheduler[0]
|
||||
|
||||
@property
|
||||
def scheduler_weights(self):
|
||||
"""
|
||||
Returns the scheduler associated with the mask model (if applicable).
|
||||
|
||||
:return: The scheduler for the mask model.
|
||||
:rtype: torch.optim.lr_scheduler._LRScheduler
|
||||
"""
|
||||
return self._scheduler[1]
|
||||
|
||||
@property
|
||||
def optimizer_model(self):
|
||||
"""
|
||||
Returns the optimizer associated with the neural network model.
|
||||
|
||||
:return: The optimizer for the neural network model.
|
||||
:rtype: torch.optim.Optimizer
|
||||
"""
|
||||
return self.optimizers[0]
|
||||
|
||||
@property
|
||||
def optimizer_weights(self):
|
||||
"""
|
||||
Returns the optimizer associated with the mask model (if applicable).
|
||||
|
||||
:return: The optimizer for the mask model.
|
||||
:rtype: torch.optim.Optimizer
|
||||
"""
|
||||
return self.optimizers[1]
|
||||
190
pina/solvers/rom.py
Normal file
190
pina/solvers/rom.py
Normal file
@@ -0,0 +1,190 @@
|
||||
""" Module for ReducedOrderModelSolver """
|
||||
|
||||
import torch
|
||||
|
||||
from pina.solvers import SupervisedSolver
|
||||
|
||||
class ReducedOrderModelSolver(SupervisedSolver):
|
||||
r"""
|
||||
ReducedOrderModelSolver solver class. This class implements a
|
||||
Reduced Order Model solver, using user specified ``reduction_network`` and
|
||||
``interpolation_network`` to solve a specific ``problem``.
|
||||
|
||||
The Reduced Order Model approach aims to find
|
||||
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
|
||||
of the differential problem:
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{cases}
|
||||
\mathcal{A}[\mathbf{u}(\mu)](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
|
||||
\mathcal{B}[\mathbf{u}(\mu)](\mathbf{x})=0\quad,
|
||||
\mathbf{x}\in\partial\Omega
|
||||
\end{cases}
|
||||
|
||||
This is done by using two neural networks. The ``reduction_network``, which
|
||||
contains an encoder :math:`\mathcal{E}_{\rm{net}}`, a decoder
|
||||
:math:`\mathcal{D}_{\rm{net}}`; and an ``interpolation_network``
|
||||
:math:`\mathcal{I}_{\rm{net}}`. The input is assumed to be discretised in
|
||||
the spatial dimensions.
|
||||
|
||||
The following loss function is minimized during training
|
||||
|
||||
.. math::
|
||||
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)] -
|
||||
\mathcal{I}_{\rm{net}}[\mu_i]) +
|
||||
\mathcal{L}(
|
||||
\mathcal{D}_{\rm{net}}[\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)]] -
|
||||
\mathbf{u}(\mu_i))
|
||||
|
||||
where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}(v) = \| v \|^2_2.
|
||||
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Hesthaven, Jan S., and Stefano Ubbiali.
|
||||
"Non-intrusive reduced order modeling of nonlinear problems
|
||||
using neural networks." Journal of Computational
|
||||
Physics 363 (2018): 55-78.
|
||||
DOI `10.1016/j.jcp.2018.02.037
|
||||
<https://doi.org/10.1016/j.jcp.2018.02.037>`_.
|
||||
|
||||
.. note::
|
||||
The specified ``reduction_network`` must contain two methods,
|
||||
namely ``encode`` for input encoding and ``decode`` for decoding the
|
||||
former result. The ``interpolation_network`` network ``forward`` output
|
||||
represents the interpolation of the latent space obtain with
|
||||
``reduction_network.encode``.
|
||||
|
||||
.. note::
|
||||
This solver uses the end-to-end training strategy, i.e. the
|
||||
``reduction_network`` and ``interpolation_network`` are trained
|
||||
simultaneously. For reference on this trainig strategy look at:
|
||||
Pichi, Federico, Beatriz Moya, and Jan S. Hesthaven.
|
||||
"A graph convolutional autoencoder approach to model order reduction
|
||||
for parametrized PDEs." Journal of
|
||||
Computational Physics 501 (2024): 112762.
|
||||
DOI
|
||||
`10.1016/j.jcp.2024.112762 <https://doi.org/10.1016/
|
||||
j.jcp.2024.112762>`_.
|
||||
|
||||
.. warning::
|
||||
This solver works only for data-driven model. Hence in the ``problem``
|
||||
definition the codition must only contain ``input_points``
|
||||
(e.g. coefficient parameters, time parameters), and ``output_points``.
|
||||
|
||||
.. warning::
|
||||
This solver does not currently support the possibility to pass
|
||||
``extra_feature``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
reduction_network,
|
||||
interpolation_network,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer=torch.optim.Adam,
|
||||
optimizer_kwargs={"lr": 0.001},
|
||||
scheduler=torch.optim.lr_scheduler.ConstantLR,
|
||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
:param torch.nn.Module reduction_network: The reduction network used
|
||||
for reducing the input space. It must contain two methods,
|
||||
namely ``encode`` for input encoding and ``decode`` for decoding the
|
||||
former result.
|
||||
:param torch.nn.Module interpolation_network: The interpolation network
|
||||
for interpolating the control parameters to latent space obtain by
|
||||
the ``reduction_network`` encoding.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default :class:`torch.nn.MSELoss`.
|
||||
:param torch.nn.Module extra_features: The additional input
|
||||
features to use as augmented input.
|
||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
||||
use; default is :class:`torch.optim.Adam`.
|
||||
:param dict optimizer_kwargs: Optimizer constructor keyword args.
|
||||
:param float lr: The learning rate; default is 0.001.
|
||||
:param torch.optim.LRScheduler scheduler: Learning
|
||||
rate scheduler.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
"""
|
||||
model = torch.nn.ModuleDict({
|
||||
'reduction_network' : reduction_network,
|
||||
'interpolation_network' : interpolation_network})
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
problem=problem,
|
||||
loss=loss,
|
||||
optimizer=optimizer,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
scheduler=scheduler,
|
||||
scheduler_kwargs=scheduler_kwargs
|
||||
)
|
||||
|
||||
# assert reduction object contains encode/ decode
|
||||
if not hasattr(self.neural_net['reduction_network'], 'encode'):
|
||||
raise SyntaxError('reduction_network must have encode method. '
|
||||
'The encode method should return a lower '
|
||||
'dimensional representation of the input.')
|
||||
if not hasattr(self.neural_net['reduction_network'], 'decode'):
|
||||
raise SyntaxError('reduction_network must have decode method. '
|
||||
'The decode method should return a high '
|
||||
'dimensional representation of the encoding.')
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass implementation for the solver. It finds the encoder
|
||||
representation by calling ``interpolation_network.forward`` on the
|
||||
input, and maps this representation to output space by calling
|
||||
``reduction_network.decode``.
|
||||
|
||||
:param torch.Tensor x: Input tensor.
|
||||
:return: Solver solution.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
reduction_network = self.neural_net['reduction_network']
|
||||
interpolation_network = self.neural_net['interpolation_network']
|
||||
return reduction_network.decode(interpolation_network(x))
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
"""
|
||||
The data loss for the ReducedOrderModelSolver solver.
|
||||
It computes the loss between
|
||||
the network output against the true solution. This function
|
||||
should not be override if not intentionally.
|
||||
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
:param LabelTensor output_tensor: The true solution to compare the
|
||||
network solution.
|
||||
:return: The residual loss averaged on the input coordinates
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
# extract networks
|
||||
reduction_network = self.neural_net['reduction_network']
|
||||
interpolation_network = self.neural_net['interpolation_network']
|
||||
# encoded representations loss
|
||||
encode_repr_inter_net = interpolation_network(input_pts)
|
||||
encode_repr_reduction_network = reduction_network.encode(output_pts)
|
||||
loss_encode = self.loss(encode_repr_inter_net,
|
||||
encode_repr_reduction_network)
|
||||
# reconstruction loss
|
||||
loss_reconstruction = self.loss(
|
||||
reduction_network.decode(encode_repr_reduction_network),
|
||||
output_pts)
|
||||
|
||||
return loss_encode + loss_reconstruction
|
||||
|
||||
@property
|
||||
def neural_net(self):
|
||||
"""
|
||||
Neural network for training. It returns a :obj:`~torch.nn.ModuleDict`
|
||||
containing the ``reduction_network`` and ``interpolation_network``.
|
||||
"""
|
||||
return self._neural_net.torchmodel
|
||||
@@ -6,6 +6,7 @@ import pytorch_lightning
|
||||
from ..utils import check_consistency
|
||||
from ..problem import AbstractProblem
|
||||
import torch
|
||||
import sys
|
||||
|
||||
|
||||
class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||
@@ -141,6 +142,20 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
The problem formulation."""
|
||||
return self._pina_problem
|
||||
|
||||
def on_train_start(self):
|
||||
"""
|
||||
On training epoch start this function is call to do global checks for
|
||||
the different solvers.
|
||||
"""
|
||||
|
||||
# 1. Check the verison for dataloader
|
||||
dataloader = self.trainer.train_dataloader
|
||||
if sys.version_info < (3, 8):
|
||||
dataloader = dataloader.loaders
|
||||
self._dataloader = dataloader
|
||||
|
||||
return super().on_train_start()
|
||||
|
||||
# @model.setter
|
||||
# def model(self, new_model):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
""" Module for SupervisedSolver """
|
||||
|
||||
import torch
|
||||
import sys
|
||||
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||
@@ -20,9 +19,32 @@ from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
class SupervisedSolver(SolverInterface):
|
||||
"""
|
||||
r"""
|
||||
SupervisedSolver solver class. This class implements a SupervisedSolver,
|
||||
using a user specified ``model`` to solve a specific ``problem``.
|
||||
|
||||
The Supervised Solver class aims to find
|
||||
a map between the input :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m`
|
||||
and the output :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. The input
|
||||
can be discretised in space (as in :obj:`~pina.solvers.rom.ROMe2eSolver`),
|
||||
or not (e.g. when training Neural Operators).
|
||||
|
||||
Given a model :math:`\mathcal{M}`, the following loss function is
|
||||
minimized during training:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i))
|
||||
|
||||
where :math:`\mathcal{L}` is a specific loss function,
|
||||
default Mean Square Error:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}(v) = \| v \|^2_2.
|
||||
|
||||
In this context :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` means that
|
||||
we are seeking to approximate multiple (discretised) functions given
|
||||
multiple (discretised) input functions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -96,18 +118,12 @@ class SupervisedSolver(SolverInterface):
|
||||
:return: The sum of the loss functions.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
dataloader = self.trainer.train_dataloader
|
||||
|
||||
condition_idx = batch["condition"]
|
||||
|
||||
for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
condition_name = dataloader.condition_names[condition_id]
|
||||
else:
|
||||
condition_name = dataloader.loaders.condition_names[
|
||||
condition_id
|
||||
]
|
||||
condition_name = self._dataloader.condition_names[condition_id]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
pts = batch["pts"]
|
||||
out = batch["output"]
|
||||
@@ -118,14 +134,14 @@ class SupervisedSolver(SolverInterface):
|
||||
# for data driven mode
|
||||
if not hasattr(condition, "output_points"):
|
||||
raise NotImplementedError(
|
||||
"Supervised solver works only in data-driven mode."
|
||||
f"{type(self).__name__} works only in data-driven mode."
|
||||
)
|
||||
|
||||
output_pts = out[condition_idx == condition_id]
|
||||
input_pts = pts[condition_idx == condition_id]
|
||||
|
||||
loss = (
|
||||
self.loss(self.forward(input_pts), output_pts)
|
||||
self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
||||
* condition.data_weight
|
||||
)
|
||||
loss = loss.as_subclass(torch.Tensor)
|
||||
@@ -133,6 +149,20 @@ class SupervisedSolver(SolverInterface):
|
||||
self.log("mean_loss", float(loss), prog_bar=True, logger=True)
|
||||
return loss
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
"""
|
||||
The data loss for the Supervised solver. It computes the loss between
|
||||
the network output against the true solution. This function
|
||||
should not be override if not intentionally.
|
||||
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
:param LabelTensor output_tensor: The true solution to compare the
|
||||
network solution.
|
||||
:return: The residual loss averaged on the input coordinates
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.loss(self.forward(input_pts), output_pts)
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user