Update solvers (#434)

* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
Authored by Dario Coscia on 2025-02-17 11:26:21 +01:00; committed by Nicola Demo
parent 780c4921eb
commit 9cae9a438f
50 changed files with 2848 additions and 4187 deletions

View File

@@ -1,17 +1,17 @@
__all__ = [
"PINNInterface",
"PINN",
"GPINN",
"GradientPINN",
"CausalPINN",
"CompetitivePINN",
"SAPINN",
"SelfAdaptivePINN",
"RBAPINN",
]
from .pinn_interface import PINNInterface
from .pinn import PINN
from .gpinn import GPINN
from .causalpinn import CausalPINN
from .rba_pinn import RBAPINN
from .causal_pinn import CausalPINN
from .gradient_pinn import GradientPINN
from .competitive_pinn import CompetitivePINN
from .sapinn import SAPINN
from .rbapinn import RBAPINN
from .self_adaptive_pinn import SelfAdaptivePINN
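The renaming in `__all__` above changes the public import surface: `GPINN` becomes `GradientPINN` and `SAPINN` becomes `SelfAdaptivePINN`, while `CausalPINN` and `RBAPINN` keep their names but move to new module files. A minimal sketch of the downstream change; the top-level package path `pina.solvers` is an assumption, since only the relative imports of this `__init__.py` are shown here.

```python
# Hypothetical import update for user code after this commit.
# Before: from pina.solvers import GPINN, SAPINN
from pina.solvers import GradientPINN, SelfAdaptivePINN, CausalPINN, RBAPINN
```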

View File

@@ -1,18 +1,15 @@
""" Module for CausalPINN """
""" Module for Causal PINN. """
import torch
from torch.optim.lr_scheduler import ConstantLR
from .pinn import PINN
from pina.problem import TimeDependentProblem
from .pinn import PINN
from pina.utils import check_consistency
class CausalPINN(PINN):
r"""
Causal Physics Informed Neural Network (PINN) solver class.
Causal Physics Informed Neural Network (CausalPINN) solver class.
This class implements Causal Physics Informed Neural
Network solvers, using a user specified ``model`` to solve a specific
``problem``. It can be used for solving both forward and inverse problems.
@@ -70,45 +67,33 @@ class CausalPINN(PINN):
:class:`~pina.problem.timedep_problem.TimeDependentProblem` class.
"""
def __init__(
self,
problem,
model,
extra_features=None,
loss=torch.nn.MSELoss(),
optimizer=torch.optim.Adam,
optimizer_kwargs={"lr": 0.001},
scheduler=ConstantLR,
scheduler_kwargs={"factor": 1, "total_iters": 0},
eps=100,
):
def __init__(self,
problem,
model,
optimizer=None,
scheduler=None,
weighting=None,
loss=None,
eps=100):
"""
:param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module model: The neural network model to use.
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
:param torch.nn.Module extra_features: The additional input
features to use as augmented input.
:param AbstractProblem problem: The formulation of the problem.
:param torch.optim.Optimizer optimizer: The neural network optimizer to
use; default is :class:`torch.optim.Adam`.
:param dict optimizer_kwargs: Optimizer constructor keyword args.
:param torch.optim.LRScheduler scheduler: Learning
rate scheduler.
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
:param int | float eps: The exponential decay parameter. Note that this
value is kept fixed during the training, but can be changed by means
of a callback, e.g. for annealing.
use; default `None`.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler;
default `None`.
:param WeightingInterface weighting: The weighting schema to use;
default `None`.
:param torch.nn.Module loss: The loss function to be minimized;
default `None`.
:param float eps: The exponential decay parameter; default `100`.
"""
super().__init__(
problem=problem,
model=model,
extra_features=extra_features,
loss=loss,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
)
super().__init__(model=model,
problem=problem,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
loss=loss)
# checking consistency
check_consistency(eps, (int, float))
@@ -116,7 +101,7 @@ class CausalPINN(PINN):
if not isinstance(self.problem, TimeDependentProblem):
raise ValueError(
"Casual PINN works only for problems"
"inheritig from TimeDependentProblem."
"inheriting from TimeDependentProblem."
)
def loss_phys(self, samples, equation):
@@ -134,8 +119,8 @@ class CausalPINN(PINN):
# split sequentially ordered time tensors into chunks
chunks, labels = self._split_tensor_into_chunks(samples)
# compute residuals - this correspond to ordered loss functions
# values for each time step. We apply `flatten` such that after
# concataning the residuals we obtain a tensor of shape #chunks
# values for each time step. Apply `flatten` so that concatenating
# the residuals yields a tensor of shape #chunks
time_loss = []
for chunk in chunks:
chunk.labels = labels
@@ -145,11 +130,10 @@ class CausalPINN(PINN):
torch.zeros_like(residual, requires_grad=True), residual
)
time_loss.append(loss_val)
# store results
self.store_log(loss_value=float(sum(time_loss) / len(time_loss)))
# concatenate residuals
time_loss = torch.stack(time_loss)
# compute weights (without the gradient storing)
# compute weights without storing the gradient
with torch.no_grad():
weights = self._compute_weights(time_loss)
return (weights * time_loss).mean()
@@ -197,17 +181,17 @@ class CausalPINN(PINN):
:return: Tuple containing the chunks and the original labels.
:rtype: Tuple[List[LabelTensor], List]
"""
# labels input tensors
# extract labels
labels = tensor.labels
# labels input tensors
# sort input tensor based on time
tensor = self._sort_label_tensor(tensor)
# extract time tensor
time_tensor = tensor.extract(self.problem.temporal_domain.variables)
# count unique tensors in time
_, idx_split = time_tensor.unique(return_counts=True)
# splitting
# split the tensor based on time
chunks = torch.split(tensor, tuple(idx_split))
return chunks, labels # return chunks
return chunks, labels
def _compute_weights(self, loss):
"""
@@ -217,7 +201,7 @@ class CausalPINN(PINN):
:return: The computed weights for the physics loss.
:rtype: LabelTensor
"""
# compute comulative loss and multiply by epsilos
# compute cumulative loss and multiply by epsilon
cumulative_loss = self._eps * torch.cumsum(loss, dim=0)
# return the exponential of the weghited negative cumulative sum
# return the exponential of the negative weighted cumulative sum
return torch.exp(-cumulative_loss)
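The causal weighting implemented by `_compute_weights` above gives each time-ordered chunk the weight w_i = exp(-eps * sum_{k<=i} L_k), so later times only matter once earlier residuals are small. A standalone sketch of the same computation with plain torch (no PINA dependency), assuming the per-chunk losses are already ordered in time:

```python
import torch

def causal_weights(chunk_losses: torch.Tensor, eps: float = 100.0) -> torch.Tensor:
    # Mirrors CausalPINN._compute_weights: exponential of the negative
    # weighted cumulative loss (eps defaults to 100 as in the constructor).
    cumulative_loss = eps * torch.cumsum(chunk_losses, dim=0)
    return torch.exp(-cumulative_loss)

# Toy example: three time chunks; later chunks receive smaller weights.
losses = torch.tensor([0.5, 0.4, 0.3])
with torch.no_grad():                        # weights carry no gradient, as in loss_phys
    weights = causal_weights(losses, eps=1.0)
weighted_loss = (weights * losses).mean()    # the value loss_phys returns
```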

View File

@@ -1,23 +1,14 @@
""" Module for CompetitivePINN """
""" Module for Competitive PINN. """
import torch
import copy
try:
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
except ImportError:
from torch.optim.lr_scheduler import (
_LRScheduler as LRScheduler,
) # torch < 2.0
from torch.optim.lr_scheduler import ConstantLR
from .pinn_interface import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem
from .pinn_interface import PINNInterface
from ..solver import MultiSolverInterface
class CompetitivePINN(PINNInterface):
class CompetitivePINN(PINNInterface, MultiSolverInterface):
r"""
Competitive Physics Informed Neural Network (PINN) solver class.
This class implements Competitive Physics Informed Neural
@@ -64,82 +55,49 @@ class CompetitivePINN(PINNInterface):
``extra_feature``.
"""
def __init__(
self,
problem,
model,
discriminator=None,
loss=torch.nn.MSELoss(),
optimizer_model=torch.optim.Adam,
optimizer_model_kwargs={"lr": 0.001},
optimizer_discriminator=torch.optim.Adam,
optimizer_discriminator_kwargs={"lr": 0.001},
scheduler_model=ConstantLR,
scheduler_model_kwargs={"factor": 1, "total_iters": 0},
scheduler_discriminator=ConstantLR,
scheduler_discriminator_kwargs={"factor": 1, "total_iters": 0},
):
def __init__(self,
problem,
model,
discriminator=None,
optimizer_model=None,
optimizer_discriminator=None,
scheduler_model=None,
scheduler_discriminator=None,
weighting=None,
loss=None):
"""
:param AbstractProblem problem: The formualation of the problem.
:param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module model: The neural network model to use
for the model.
:param torch.nn.Module discriminator: The neural network model to use
for the discriminator. If ``None``, the discriminator network will
have the same architecture as the model network.
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
:param torch.optim.Optimizer optimizer_model: The neural
network optimizer to use for the model network
, default is `torch.optim.Adam`.
:param dict optimizer_model_kwargs: Optimizer constructor keyword
args. for the model.
:param torch.optim.Optimizer optimizer_discriminator: The neural
network optimizer to use for the discriminator network
, default is `torch.optim.Adam`.
:param dict optimizer_discriminator_kwargs: Optimizer constructor
keyword args. for the discriminator.
:param torch.optim.LRScheduler scheduler_model: Learning
rate scheduler for the model.
:param dict scheduler_model_kwargs: LR scheduler constructor
keyword args.
:param torch.optim.LRScheduler scheduler_discriminator: Learning
rate scheduler for the discriminator.
:param torch.optim.Optimizer optimizer_model: The neural network
optimizer to use for the model network; default `None`.
:param torch.optim.Optimizer optimizer_discriminator: The neural network
optimizer to use for the discriminator network; default `None`.
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
for the model; default `None`.
:param torch.optim.LRScheduler scheduler_discriminator: Learning rate
scheduler for the discriminator; default `None`.
:param WeightingInterface weighting: The weighting schema to use;
default `None`.
:param torch.nn.Module loss: The loss function to be minimized;
default `None`.
"""
if discriminator is None:
discriminator = copy.deepcopy(model)
super().__init__(
models=[model, discriminator],
problem=problem,
optimizers=[optimizer_model, optimizer_discriminator],
optimizers_kwargs=[
optimizer_model_kwargs,
optimizer_discriminator_kwargs,
],
extra_features=None, # CompetitivePINN doesn't take extra features
loss=loss,
)
super().__init__(models=[model, discriminator],
problem=problem,
optimizers=[optimizer_model, optimizer_discriminator],
schedulers=[scheduler_model, scheduler_discriminator],
weighting=weighting,
loss=loss)
# set automatic optimization for GANs
# Set automatic optimization to False
self.automatic_optimization = False
# check consistency
check_consistency(scheduler_model, LRScheduler, subclass=True)
check_consistency(scheduler_model_kwargs, dict)
check_consistency(scheduler_discriminator, LRScheduler, subclass=True)
check_consistency(scheduler_discriminator_kwargs, dict)
# assign schedulers
self._schedulers = [
scheduler_model(self.optimizers[0], **scheduler_model_kwargs),
scheduler_discriminator(
self.optimizers[1], **scheduler_discriminator_kwargs
),
]
self._model = self.models[0]
self._discriminator = self.models[1]
def forward(self, x):
r"""
Forward pass implementation for the PINN solver. It returns the function
@@ -154,6 +112,22 @@ class CompetitivePINN(PINNInterface):
"""
return self.neural_net(x)
def training_step(self, batch):
"""
Solver training step, overridden to perform manual optimization.
:param batch: The batch element in the dataloader.
:type batch: tuple
:return: The sum of the loss functions.
:rtype: LabelTensor
"""
self.optimizer_model.instance.zero_grad()
self.optimizer_discriminator.instance.zero_grad()
loss = super().training_step(batch)
self.optimizer_model.instance.step()
self.optimizer_discriminator.instance.step()
return loss
def loss_phys(self, samples, equation):
"""
Computes the physics loss for the Competitive PINN solver based on given
@@ -166,25 +140,26 @@ class CompetitivePINN(PINNInterface):
samples and equation.
:rtype: LabelTensor
"""
# train one step of the model
# Train the model for one step
with torch.no_grad():
discriminator_bets = self.discriminator(samples)
loss_val = self._train_model(samples, equation, discriminator_bets)
self.store_log(loss_value=float(loss_val))
# detaching samples from the computational graph to erase it and setting
# the gradient to true to create a new computational graph.
# Detach samples from the existing computational graph and
# create a new one by setting requires_grad to True.
# Alternatively, set `retain_graph=True`.
samples = samples.detach()
samples.requires_grad = True
# train one step of discriminator
samples.requires_grad_()
# Train the discriminator for one step
discriminator_bets = self.discriminator(samples)
self._train_discriminator(samples, equation, discriminator_bets)
return loss_val
def loss_data(self, input_tensor, output_tensor):
def loss_data(self, input_pts, output_pts):
"""
The data loss for the PINN solver. It computes the loss between the
network output against the true solution.
The data loss for the CompetitivePINN solver. It computes the loss
between the network output against the true solution.
:param LabelTensor input_tensor: The input to the neural networks.
:param LabelTensor output_tensor: The true solution to compare the
@@ -192,14 +167,9 @@ class CompetitivePINN(PINNInterface):
:return: The computed data loss.
:rtype: torch.Tensor
"""
self.optimizer_model.zero_grad()
loss_val = (
super()
.loss_data(input_tensor, output_tensor)
.as_subclass(torch.Tensor)
)
loss_val = (super().loss_data(input_pts, output_pts))
# prepare for optimizer step called in training step
loss_val.backward()
self.optimizer_model.step()
return loss_val
def configure_optimizers(self):
@@ -209,10 +179,12 @@ class CompetitivePINN(PINNInterface):
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
# if the problem is an InverseProblem, add the unknown parameters
# to the parameters that the optimizer needs to optimize
# If the problem is an InverseProblem, add the unknown parameters
# to the parameters to be optimized
self.optimizer_model.hook(self.neural_net.parameters())
self.optimizer_discriminator.hook(self.discriminator.parameters())
if isinstance(self.problem, InverseProblem):
self.optimizer_model.add_param_group(
self.optimizer_model.instance.add_param_group(
{
"params": [
self._params[var]
@@ -220,7 +192,14 @@ class CompetitivePINN(PINNInterface):
]
}
)
return self.optimizers, self._schedulers
self.scheduler_model.hook(self.optimizer_model)
self.scheduler_discriminator.hook(self.optimizer_discriminator)
return (
[self.optimizer_model.instance,
self.optimizer_discriminator.instance],
[self.scheduler_model.instance,
self.scheduler_discriminator.instance]
)
def on_train_batch_end(self, outputs, batch, batch_idx):
"""
@@ -236,9 +215,11 @@ class CompetitivePINN(PINNInterface):
:rtype: Any
"""
# increase by one the counter of optimization to save loggers
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += (
1
)
(
self.trainer.fit_loop.epoch_loop.manual_optimization
.optim_step_progress.total.completed
) += 1
return super().on_train_batch_end(outputs, batch, batch_idx)
def _train_discriminator(self, samples, equation, discriminator_bets):
@@ -251,22 +232,19 @@ class CompetitivePINN(PINNInterface):
:param Tensor discriminator_bets: Predictions made by the discriminator
network.
"""
# manual optimization
self.optimizer_discriminator.zero_grad()
# compute residual, we detach because the weights of the generator
# model are fixed
# Compute residual. Detach since the model (generator) weights are fixed
residual = self.compute_residual(
samples=samples, equation=equation
).detach()
# compute competitive residual, the minus is because we maximise
# Compute competitive residual; the minus sign below maximises the loss
competitive_residual = residual * discriminator_bets
loss_val = -self.loss(
torch.zeros_like(competitive_residual, requires_grad=True),
competitive_residual,
).as_subclass(torch.Tensor)
# backprop
)
# prepare for optimizer step called in training step
self.manual_backward(loss_val)
self.optimizer_discriminator.step()
return
def _train_model(self, samples, equation, discriminator_bets):
@@ -281,23 +259,20 @@ class CompetitivePINN(PINNInterface):
:return: The computed data loss.
:rtype: torch.Tensor
"""
# manual optimization
self.optimizer_model.zero_grad()
# compute residual (detached for discriminator) and log
# Compute residual
residual = self.compute_residual(samples=samples, equation=equation)
# store logging
with torch.no_grad():
loss_residual = self.loss(torch.zeros_like(residual), residual)
# compute competitive residual, discriminator_bets are detached becase
# we optimize only the generator model
# Compute competitive residual. Detach discriminator_bets
# to optimize only the generator model
competitive_residual = residual * discriminator_bets.detach()
loss_val = self.loss(
torch.zeros_like(competitive_residual, requires_grad=True),
competitive_residual,
).as_subclass(torch.Tensor)
# backprop
)
# prepare for optimizer step called in training step
self.manual_backward(loss_val)
self.optimizer_model.step()
return loss_residual
@property
@@ -308,7 +283,7 @@ class CompetitivePINN(PINNInterface):
:return: The neural network model.
:rtype: torch.nn.Module
"""
return self._model
return self.models[0]
@property
def discriminator(self):
@@ -318,7 +293,7 @@ class CompetitivePINN(PINNInterface):
:return: The discriminator model.
:rtype: torch.nn.Module
"""
return self._discriminator
return self.models[1]
@property
def optimizer_model(self):
@@ -348,7 +323,7 @@ class CompetitivePINN(PINNInterface):
:return: The scheduler for the neural network model.
:rtype: torch.optim.lr_scheduler._LRScheduler
"""
return self._schedulers[0]
return self.schedulers[0]
@property
def scheduler_discriminator(self):
@@ -358,4 +333,4 @@ class CompetitivePINN(PINNInterface):
:return: The scheduler for the discriminator.
:rtype: torch.optim.lr_scheduler._LRScheduler
"""
return self._schedulers[1]
return self.schedulers[1]
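With the refactored interface, the model and the discriminator each receive their own optimizer/scheduler object, hooked lazily in `configure_optimizers`, while `training_step` performs the manual `zero_grad`/`step` for both. A hedged construction sketch: `TorchOptimizer`/`TorchScheduler` are the wrappers used as defaults in `pinn_interface.py` below, the `pina.optim` and `pina.solvers` paths are inferred from the relative imports, and `problem` is assumed to be an `AbstractProblem` defined elsewhere.

```python
import torch
from pina.optim import TorchOptimizer, TorchScheduler   # path inferred from "from ...optim import ..."
from pina.solvers import CompetitivePINN                 # package path assumed

# Illustrative network; input/output sizes depend on the problem at hand.
model = torch.nn.Sequential(torch.nn.Linear(2, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1))

solver = CompetitivePINN(
    problem=problem,                                      # assumed defined elsewhere
    model=model,
    discriminator=None,                                   # None -> deepcopy of `model`
    optimizer_model=TorchOptimizer(torch.optim.Adam, lr=1e-3),
    optimizer_discriminator=TorchOptimizer(torch.optim.Adam, lr=1e-3),
    scheduler_model=TorchScheduler(torch.optim.lr_scheduler.ConstantLR),
    scheduler_discriminator=TorchScheduler(torch.optim.lr_scheduler.ConstantLR),
)
```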

View File

@@ -1,18 +1,15 @@
""" Module for GPINN """
""" Module for Gradient PINN. """
import torch
from torch.optim.lr_scheduler import ConstantLR
from .pinn import PINN
from pina.operators import grad
from pina.problem import SpatialProblem
class GPINN(PINN):
class GradientPINN(PINN):
r"""
Gradient Physics Informed Neural Network (GPINN) solver class.
Gradient Physics Informed Neural Network (GradientPINN) solver class.
This class implements Gradient Physics Informed Neural
Network solvers, using a user specified ``model`` to solve a specific
``problem``. It can be used for solving both forward and inverse problems.
@@ -42,7 +39,8 @@ class GPINN(PINN):
\nabla_{\mathbf{x}}\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i))
where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error:
where :math:`\mathcal{L}` is a specific loss function,
default Mean Square Error:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
@@ -61,44 +59,35 @@ class GPINN(PINN):
class.
"""
def __init__(
self,
problem,
model,
extra_features=None,
loss=torch.nn.MSELoss(),
optimizer=torch.optim.Adam,
optimizer_kwargs={"lr": 0.001},
scheduler=ConstantLR,
scheduler_kwargs={"factor": 1, "total_iters": 0},
):
def __init__(self,
problem,
model,
optimizer=None,
scheduler=None,
weighting=None,
loss=None):
"""
:param torch.nn.Module model: The neural network model to use.
:param AbstractProblem problem: The formulation of the problem. It must
inherit from at least
:class:`~pina.problem.spatial_problem.SpatialProblem` in order to
compute the gradient of the loss.
:param torch.nn.Module model: The neural network model to use.
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
:param torch.nn.Module extra_features: The additional input
features to use as augmented input.
:class:`~pina.problem.spatial_problem.SpatialProblem` to compute
the gradient of the loss.
:param torch.optim.Optimizer optimizer: The neural network optimizer to
use; default is :class:`torch.optim.Adam`.
:param dict optimizer_kwargs: Optimizer constructor keyword args.
:param torch.optim.LRScheduler scheduler: Learning
rate scheduler.
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
use; default `None`.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler;
default `None`.
:param WeightingInterface weighting: The weighting schema to use;
default `None`.
:param torch.nn.Module loss: The loss function to be minimized;
default `None`.
"""
super().__init__(
problem=problem,
model=model,
extra_features=extra_features,
loss=loss,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
)
super().__init__(model=model,
problem=problem,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
loss=loss)
if not isinstance(self.problem, SpatialProblem):
raise ValueError(
"Gradient PINN computes the gradient of the "
@@ -124,10 +113,10 @@ class GPINN(PINN):
loss_value = self.loss(
torch.zeros_like(residual, requires_grad=True), residual
)
self.store_log(loss_value=float(loss_value))
# gradient PINN loss
loss_value = loss_value.reshape(-1, 1)
loss_value.labels = ["__LOSS"]
loss_value.labels = ["__loss"]
loss_grad = grad(loss_value, samples, d=self.problem.spatial_variables)
g_loss_phys = self.loss(
torch.zeros_like(loss_grad, requires_grad=True), loss_grad
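`GradientPINN.loss_phys` adds to the usual residual loss a penalty on the spatial gradient of that loss, computed through `pina.operators.grad`. A simplified stand-in below reproduces the idea on a toy 1D residual using `torch.autograd.grad` directly; the residual function is made up for illustration.

```python
import torch

# Collocation points (leaf tensor so we can differentiate w.r.t. them).
x = torch.linspace(0.0, 1.0, 32).reshape(-1, 1).requires_grad_(True)
residual = torch.sin(3.0 * x)                     # stand-in for compute_residual(...)

mse = torch.nn.MSELoss()
loss_value = mse(torch.zeros_like(residual), residual)

# Gradient-enhanced term: also drive d(loss)/dx to zero, as in loss_phys above.
loss_grad = torch.autograd.grad(loss_value, x, create_graph=True)[0]
g_loss_phys = mse(torch.zeros_like(loss_grad), loss_grad)

total_loss = loss_value + g_loss_phys
```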

View File

@@ -2,19 +2,12 @@
import torch
try:
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
except ImportError:
from torch.optim.lr_scheduler import (
_LRScheduler as LRScheduler,
) # torch < 2.0
from .pinn_interface import PINNInterface
from ..solver import SingleSolverInterface
from ...problem import InverseProblem
class PINN(PINNInterface):
class PINN(PINNInterface, SingleSolverInterface):
r"""
Physics Informed Neural Network (PINN) solver class.
This class implements Physics Informed Neural
@@ -41,7 +34,8 @@ class PINN(PINNInterface):
\frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i))
where :math:`\mathcal{L}` is a specific loss function, default Mean Square Error:
where :math:`\mathcal{L}` is a specific loss function,
default Mean Square Error:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
@@ -54,54 +48,31 @@ class PINN(PINNInterface):
DOI: `10.1038 <https://doi.org/10.1038/s42254-021-00314-5>`_.
"""
__name__ = 'PINN'
def __init__(
self,
problem,
model,
loss=None,
optimizer=None,
scheduler=None,
):
def __init__(self,
problem,
model,
optimizer=None,
scheduler=None,
weighting=None,
loss=None):
"""
:param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module model: The neural network model to use.
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
:param torch.nn.Module extra_features: The additional input
features to use as augmented input.
:param AbstractProblem problem: The formulation of the problem.
:param torch.optim.Optimizer optimizer: The neural network optimizer to
use; default is :class:`torch.optim.Adam`.
:param dict optimizer_kwargs: Optimizer constructor keyword args.
:param torch.optim.LRScheduler scheduler: Learning
rate scheduler.
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
use; default `None`.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler;
default `None`.
:param WeightingInterface weighting: The weighting schema to use;
default `None`.
:param torch.nn.Module loss: The loss function to be minimized;
default `None`.
"""
super().__init__(
models=model,
problem=problem,
loss=loss,
optimizers=optimizer,
schedulers=scheduler,
)
# assign variables
self._neural_net = self.models[0]
def forward(self, x):
r"""
Forward pass implementation for the PINN solver. It returns the function
evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
:math:`\mathbf{x}`.
:param LabelTensor x: Input tensor for the PINN solver. It expects
a tensor :math:`N \times D`, where :math:`N` the number of points
in the mesh, :math:`D` the dimension of the problem,
:return: PINN solution evaluated at contro points.
:rtype: LabelTensor
"""
return self.neural_net(x)
super().__init__(model=model,
problem=problem,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
loss=loss)
def loss_phys(self, samples, equation):
"""
@@ -117,46 +88,31 @@ class PINN(PINNInterface):
"""
residual = self.compute_residual(samples=samples, equation=equation)
loss_value = self.loss(
torch.zeros_like(residual), residual
torch.zeros_like(residual, requires_grad=True), residual
)
return loss_value
def configure_optimizers(self):
"""
Optimizer configuration for the PINN
solver.
Optimizer configuration for the PINN solver.
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
# if the problem is an InverseProblem, add the unknown parameters
# to the parameters that the optimizer needs to optimize
self._optimizer.hook(self._model.parameters())
# If the problem is an InverseProblem, add the unknown parameters
# to the parameters to be optimized.
self.optimizer.hook(self.model.parameters())
if isinstance(self.problem, InverseProblem):
self._optimizer.optimizer_instance.add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
self._scheduler.hook(self._optimizer)
return ([self._optimizer.optimizer_instance],
[self._scheduler.scheduler_instance])
@property
def scheduler(self):
"""
Scheduler for the PINN training.
"""
return self._scheduler
@property
def neural_net(self):
"""
Neural network for the PINN training.
"""
return self._neural_net
self.optimizer.instance.add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
self.scheduler.hook(self.optimizer)
return (
[self.optimizer.instance],
[self.scheduler.instance]
)
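After this refactor a `PINN` only needs `problem` and `model`; leaving `optimizer`, `scheduler`, `weighting`, and `loss` as `None` falls back to the library defaults (the interface below shows `MSELoss`, `TorchOptimizer(torch.optim.Adam, lr=0.001)`, and a `ConstantLR` scheduler as fallbacks). A minimal usage sketch; the `Trainer` keyword arguments other than `batch_size` (which appears in the RBAPINN/SelfAdaptivePINN checks further down) are assumptions.

```python
import torch
from pina import Trainer
from pina.solvers import PINN        # package path assumed

# `problem` is assumed to be an AbstractProblem defined elsewhere;
# layer sizes are illustrative and must match the problem's variables.
model = torch.nn.Sequential(torch.nn.Linear(2, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))

solver = PINN(problem=problem, model=model)                    # defaults for the rest
trainer = Trainer(solver, max_epochs=1000, batch_size=None)    # batch_size=None -> full batch
trainer.train()
```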

View File

@@ -1,17 +1,18 @@
""" Module for PINN """
""" Module for Physics Informed Neural Network Interface."""
from abc import ABCMeta, abstractmethod
import torch
from torch.nn.modules.loss import _Loss
from ..solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...optim import TorchOptimizer, TorchScheduler
from ...condition import InputOutputPointsCondition, \
InputPointsEquationCondition, DomainEquationCondition
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
from ...condition import (
InputOutputPointsCondition,
InputPointsEquationCondition,
DomainEquationCondition
)
class PINNInterface(SolverInterface, metaclass=ABCMeta):
@@ -19,57 +20,34 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
Base PINN solver class. This class implements the Solver Interface
for Physics Informed Neural Network solvers.
This class can be used to
define PINNs with multiple ``optimizers``, and/or ``models``.
By default it takes
an :class:`~pina.problem.abstract_problem.AbstractProblem`, so it is up
to the user to choose which problem the implemented solver inheriting from
this class is suitable for.
This class can be used to define PINNs with multiple ``optimizers``,
and/or ``models``.
By default it takes :class:`~pina.problem.abstract_problem.AbstractProblem`,
so the user can choose what type of problem the implemented solver,
inheriting from this class, is designed to solve.
"""
accepted_conditions_types = (InputOutputPointsCondition,
InputPointsEquationCondition, DomainEquationCondition)
accepted_conditions_types = (
InputOutputPointsCondition,
InputPointsEquationCondition,
DomainEquationCondition
)
def __init__(
self,
models,
problem,
loss=None,
optimizers=None,
schedulers=None,
):
def __init__(self,
problem,
loss=None,
**kwargs):
"""
:param models: Multiple torch neural network models instances.
:type models: list(torch.nn.Module)
:param problem: A problem definition instance.
:type problem: AbstractProblem
:param list(torch.optim.Optimizer) optimizer: A list of neural network
optimizers to use.
:param list(dict) optimizer_kwargs: A list of optimizer constructor
keyword args.
:param list(torch.nn.Module) extra_features: The additional input
features to use as augmented input. If ``None`` no extra features
are passed. If it is a list of :class:`torch.nn.Module`,
the extra feature list is passed to all models. If it is a list
of extra features' lists, each single list of extra feature
is passed to a model.
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
:param AbstractProblem problem: A problem definition instance.
:param torch.nn.Module loss: The loss function to be minimized,
default `None`.
"""
if optimizers is None:
optimizers = TorchOptimizer(torch.optim.Adam, lr=0.001)
if schedulers is None:
schedulers = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
if loss is None:
loss = torch.nn.MSELoss()
super().__init__(
models=models,
problem=problem,
optimizers=optimizers,
schedulers=schedulers,
)
super().__init__(problem=problem,
use_lt=True,
**kwargs)
# check consistency
check_consistency(loss, (LossInterface, _Loss), subclass=False)
@@ -85,86 +63,24 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
self._params = None
self._clamp_params = lambda: None
# variable used internally to store residual losses at each epoch
# this variable save the residual at each iteration (not weighted)
self.__logged_res_losses = []
self.__metric = None
# variable used internally in pina for logging. This variable points to
# the current condition during the training step and returns the
# condition name. Whenever :meth:`store_log` is called the logged
# variable will be stored with name = self.__logged_metric
self.__logged_metric = None
self._model = self._pina_models[0]
self._optimizer = self._pina_optimizers[0]
self._scheduler = self._pina_schedulers[0]
def training_step(self, batch):
"""
The Physics Informed Solver Training Step. This function takes care
of the physics informed training step, and it must not be override
if not intentionally. It handles the batching mechanism, the workload
division for the various conditions, the inverse problem clamping,
and loggers.
:param tuple batch: The batch element in the dataloader.
:param int batch_idx: The batch index.
:return: The sum of the loss functions.
:rtype: LabelTensor
"""
condition_loss = []
for condition_name, points in batch:
if 'output_points' in points:
input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(
input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor))
else:
input_pts = points['input_points']
condition = self.problem.conditions[condition_name]
loss_ = self.loss_phys(
input_pts.requires_grad_(), condition.equation)
condition_loss.append(loss_.as_subclass(torch.Tensor))
condition_loss.append(loss_.as_subclass(torch.Tensor))
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
loss = sum(condition_loss)
self.log('train_loss', loss, prog_bar=True, on_epoch=True,
logger=True, batch_size=self.get_batch_size(batch),
sync_dist=True)
def optimization_cycle(self, batch):
return self._run_optimization_cycle(batch, self.loss_phys)
@torch.set_grad_enabled(True)
def validation_step(self, batch):
losses = self._run_optimization_cycle(batch, self._residual_loss)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
self.store_log('val_loss', loss, self.get_batch_size(batch))
return loss
def validation_step(self, batch):
"""
TODO: add docstring
"""
condition_loss = []
for condition_name, points in batch:
if 'output_points' in points:
input_pts, output_pts = points['input_points'], points['output_points']
loss_ = self.loss_data(
input_pts=input_pts, output_pts=output_pts)
condition_loss.append(loss_.as_subclass(torch.Tensor))
else:
input_pts = points['input_points']
condition = self.problem.conditions[condition_name]
with torch.set_grad_enabled(True):
loss_ = self.loss_phys(
input_pts.requires_grad_(), condition.equation)
condition_loss.append(loss_.as_subclass(torch.Tensor))
condition_loss.append(loss_.as_subclass(torch.Tensor))
# clamp unknown parameters in InverseProblem (if needed)
loss = sum(condition_loss)
self.log('val_loss', loss, on_epoch=True, prog_bar=True,
logger=True, batch_size=self.get_batch_size(batch),
sync_dist=True)
@torch.set_grad_enabled(True)
def test_step(self, batch):
losses = self._run_optimization_cycle(batch, self._residual_loss)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
self.store_log('test_loss', loss, self.get_batch_size(batch))
return loss
def loss_data(self, input_pts, output_pts):
"""
@@ -196,11 +112,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
"""
pass
def configure_optimizers(self):
self._optimizer.hook(self._model)
self.schedulers.hook(self._optimizer)
return [self.optimizers.instance]#, self.schedulers.scheduler_instance
def compute_residual(self, samples, equation):
"""
Compute the residual for Physics Informed learning. This function
@@ -215,53 +126,45 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
"""
try:
residual = equation.residual(samples, self.forward(samples))
except (
TypeError
): # this occurs when the function has three inputs, i.e. inverse problem
except TypeError:
# this occurs when the function has three inputs (inverse problem)
residual = equation.residual(
samples, self.forward(samples), self._params
samples,
self.forward(samples),
self._params
)
return residual
def store_log(self, loss_value):
"""
Stores the loss value in the logger. This function should be
called for all conditions. It automatically handles the storing
conditions names. It must be used
anytime a specific variable wants to be stored for a specific condition.
A simple example is to use the variable to store the residual.
:param str name: The name of the loss.
:param torch.Tensor loss_value: The value of the loss.
"""
batch_size = self.trainer.data_module.batch_size \
if self.trainer.data_module.batch_size is not None else 999
self.log(
self.__logged_metric + "_loss",
loss_value,
prog_bar=True,
logger=True,
on_epoch=True,
on_step=True,
batch_size=batch_size,
)
self.__logged_res_losses.append(loss_value)
def save_logs_and_release(self):
"""
At the end of each epoch we free the stored losses. This function
should not be override if not intentionally.
"""
if self.__logged_res_losses:
# storing mean loss
self.__logged_metric = "mean"
self.store_log(
sum(self.__logged_res_losses) / len(self.__logged_res_losses)
)
# free the logged losses
self.__logged_res_losses = []
def _residual_loss(self, samples, equation):
residuals = self.compute_residual(samples, equation)
return self.loss(residuals, torch.zeros_like(residuals))
def _run_optimization_cycle(self, batch, loss_residuals):
condition_loss = {}
for condition_name, points in batch:
self.__metric = condition_name
# if equations are passed
if 'output_points' not in points:
input_pts = points['input_points']
condition = self.problem.conditions[condition_name]
loss = loss_residuals(
input_pts.requires_grad_(),
condition.equation
)
# if data are passed
else:
input_pts = points['input_points']
output_pts = points['output_points']
loss = self.loss_data(
input_pts=input_pts.requires_grad_(),
output_pts=output_pts
)
# append loss
condition_loss[condition_name] = loss
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
return condition_loss
def _clamp_inverse_problem_params(self):
"""
Clamps the parameters of the inverse problem
@@ -272,19 +175,17 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
self.problem.unknown_parameter_domain.range_[v][0],
self.problem.unknown_parameter_domain.range_[v][1],
)
@property
def loss(self):
"""
Loss used for training.
"""
return self._loss
@property
def current_condition_name(self):
"""
Returns the condition name. This function can be used inside the
:meth:`loss_phys` to extract the condition at which the loss is
computed.
The current condition name.
"""
return self.__logged_metric
return self.__metric
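`_run_optimization_cycle` returns a dictionary mapping each condition name to its scalar loss; `optimization_cycle`, `validation_step`, and `test_step` then pass that dictionary to the weighting schema's `aggregate`. A toy sketch of the same flow, under the assumption that the default weighting simply sums the per-condition losses; the lambdas are stand-ins for `loss_data`/`loss_phys` and the equation argument is omitted.

```python
import torch

def run_cycle(batch, loss_phys, loss_data):
    # Mimics PINNInterface._run_optimization_cycle: one scalar loss per condition.
    condition_loss = {}
    for condition_name, points in batch:
        if "output_points" in points:
            loss = loss_data(points["input_points"], points["output_points"])
        else:
            loss = loss_phys(points["input_points"].requires_grad_())
        condition_loss[condition_name] = loss
    return condition_loss

# Toy batch: one data condition and one physics condition.
batch = [
    ("data", {"input_points": torch.rand(8, 2), "output_points": torch.rand(8, 1)}),
    ("physics", {"input_points": torch.rand(8, 2)}),
]
losses = run_cycle(
    batch,
    loss_phys=lambda pts: (pts ** 2).mean(),
    loss_data=lambda inp, out: ((inp.sum(dim=1, keepdim=True) - out) ** 2).mean(),
)
total = sum(losses.values())   # stand-in for self.weighting.aggregate(losses)
```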

View File

@@ -1,8 +1,8 @@
""" Module for RBAPINN. """
""" Module for Residual-Based Attention PINN. """
from copy import deepcopy
import torch
from torch.optim.lr_scheduler import ConstantLR
from .pinn import PINN
from ...utils import check_consistency
@@ -66,51 +66,44 @@ class RBAPINN(PINN):
j.cma.2024.116805 <https://doi.org/10.1016/j.cma.2024.116805>`_.
"""
def __init__(
self,
problem,
model,
extra_features=None,
loss=torch.nn.MSELoss(),
optimizer=torch.optim.Adam,
optimizer_kwargs={"lr": 0.001},
scheduler=ConstantLR,
scheduler_kwargs={"factor": 1, "total_iters": 0},
eta=0.001,
gamma=0.999,
):
def __init__(self,
problem,
model,
optimizer=None,
scheduler=None,
weighting=None,
loss=None,
eta=0.001,
gamma=0.999):
"""
:param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module model: The neural network model to use.
:param torch.nn.Module extra_features: The additional input
features to use as augmented input.
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
:param AbstractProblem problem: The formulation of the problem.
:param torch.optim.Optimizer optimizer: The neural network optimizer to
use; default is :class:`torch.optim.Adam`.
:param dict optimizer_kwargs: Optimizer constructor keyword args.
:param torch.optim.LRScheduler scheduler: Learning
rate scheduler.
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
:param float | int eta: The learning rate for the
weights of the residual.
:param float gamma: The decay parameter in the update of the weights
of the residual.
use; default `None`.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler;
default `None`.
:param WeightingInterface weighting: The weighting schema to use;
default `None`.
:param torch.nn.Module loss: The loss function to be minimized;
default `None`.
:param float | int eta: The learning rate for the weights of the
residual; default 0.001.
:param float gamma: The decay parameter in the update of the weights
of the residual. Must be between 0 and 1; default 0.999.
"""
super().__init__(
problem=problem,
model=model,
extra_features=extra_features,
loss=loss,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
)
super().__init__(model=model,
problem=problem,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
loss=loss)
# check consistency
check_consistency(eta, (float, int))
check_consistency(gamma, float)
assert (
0 < gamma < 1
), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"
self.eta = eta
self.gamma = gamma
@@ -120,9 +113,17 @@ class RBAPINN(PINN):
self.weights[condition_name] = 0
# define vectorial loss
self._vectorial_loss = deepcopy(loss)
self._vectorial_loss = deepcopy(self.loss)
self._vectorial_loss.reduction = "none"
# for now RBAPINN is implemented only for batch_size = None
def on_train_start(self):
if self.trainer.batch_size is not None:
raise NotImplementedError("RBAPINN only works with full batch "
"size, set batch_size=None inside the "
"Trainer to use the solver.")
return super().on_train_start()
def _vect_to_scalar(self, loss_value):
"""
Elaboration of the pointwise loss.
@@ -159,16 +160,13 @@ class RBAPINN(PINN):
cond = self.current_condition_name
r_norm = (
self.eta
* torch.abs(residual)
self.eta * torch.abs(residual)
/ (torch.max(torch.abs(residual)) + 1e-12)
)
self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
self.weights[cond] = (self.gamma*self.weights[cond] + r_norm).detach()
loss_value = self._vectorial_loss(
torch.zeros_like(residual, requires_grad=True), residual
)
self.store_log(loss_value=float(self._vect_to_scalar(loss_value)))
return self._vect_to_scalar(self.weights[cond] ** 2 * loss_value)
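The residual-based attention update above keeps one weight per collocation point and condition, updated as w <- gamma * w + eta * |r| / (max|r| + 1e-12), and returns the mean of w^2 times the pointwise loss. A standalone sketch of one update step with plain torch, assuming `_vect_to_scalar` reduces by a mean (its body is not shown in this hunk):

```python
import torch

def rba_step(weights, residual, eta=0.001, gamma=0.999):
    # One RBAPINN weight update and weighted loss, mirroring loss_phys above.
    r_norm = eta * torch.abs(residual) / (torch.max(torch.abs(residual)) + 1e-12)
    weights = (gamma * weights + r_norm).detach()
    pointwise = torch.nn.MSELoss(reduction="none")(torch.zeros_like(residual), residual)
    return weights, (weights ** 2 * pointwise).mean()

residual = torch.randn(100, 1)           # stand-in for compute_residual(...)
weights = torch.zeros_like(residual)     # weights start at 0, as in __init__
weights, loss = rba_step(weights, residual)
```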

View File

@@ -1,30 +1,23 @@
""" Module for Self-Adaptive PINN. """
import torch
from copy import deepcopy
try:
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
except ImportError:
from torch.optim.lr_scheduler import (
_LRScheduler as LRScheduler,
) # torch < 2.0
from .pinn_interface import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem
from torch.optim.lr_scheduler import ConstantLR
from ..solver import MultiSolverInterface
from .pinn_interface import PINNInterface
class Weights(torch.nn.Module):
"""
This class aims to implements the mask model for
self adaptive weights of the Self-Adaptive
PINN solver.
This class aims to implement the mask model for the
self-adaptive weights of the Self-Adaptive PINN solver.
"""
def __init__(self, func):
"""
:param torch.nn.Module func: the mask module of SAPINN
:param torch.nn.Module func: the mask module of SAPINN.
"""
super().__init__()
check_consistency(func, torch.nn.Module)
@@ -34,8 +27,7 @@ class Weights(torch.nn.Module):
def forward(self):
"""
Forward pass implementation for the mask module.
It returns the function on the weights
evaluation.
It returns the function on the weights evaluation.
:return: evaluation of self adaptive weights through the mask.
:rtype: torch.Tensor
@@ -43,10 +35,10 @@ class Weights(torch.nn.Module):
return self.func(self.sa_weights)
class SAPINN(PINNInterface):
class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
r"""
Self Adaptive Physics Informed Neural Network (SAPINN) solver class.
This class implements Self-Adaptive Physics Informed Neural
Self Adaptive Physics Informed Neural Network (SelfAdaptivePINN)
solver class. This class implements Self-Adaptive Physics Informed Neural
Network solvers, using a user specified ``model`` to solve a specific
``problem``. It can be used for solving both forward and inverse problems.
@@ -107,97 +99,55 @@ class SAPINN(PINNInterface):
j.jcp.2022.111722 <https://doi.org/10.1016/j.jcp.2022.111722>`_.
"""
def __init__(
self,
problem,
model,
weights_function=torch.nn.Sigmoid(),
extra_features=None,
loss=torch.nn.MSELoss(),
optimizer_model=torch.optim.Adam,
optimizer_model_kwargs={"lr": 0.001},
optimizer_weights=torch.optim.Adam,
optimizer_weights_kwargs={"lr": 0.001},
scheduler_model=ConstantLR,
scheduler_model_kwargs={"factor": 1, "total_iters": 0},
scheduler_weights=ConstantLR,
scheduler_weights_kwargs={"factor": 1, "total_iters": 0},
):
def __init__(self,
problem,
model,
weight_function=torch.nn.Sigmoid(),
optimizer_model=None,
optimizer_weights=None,
scheduler_model=None,
scheduler_weights=None,
weighting=None,
loss=None):
"""
:param AbstractProblem problem: The formualation of the problem.
:param torch.nn.Module model: The neural network model to use
for the model.
:param torch.nn.Module weights_function: The neural network model
related to the mask of SAPINN.
default :obj:`~torch.nn.Sigmoid`.
:param list(torch.nn.Module) extra_features: The additional input
features to use as augmented input. If ``None`` no extra features
are passed. If it is a list of :class:`torch.nn.Module`,
the extra feature list is passed to all models. If it is a list
of extra features' lists, each single list of extra feature
is passed to a model.
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
:param torch.optim.Optimizer optimizer_model: The neural
network optimizer to use for the model network
, default is `torch.optim.Adam`.
:param dict optimizer_model_kwargs: Optimizer constructor keyword
args. for the model.
:param torch.optim.Optimizer optimizer_weights: The neural
network optimizer to use for mask model model,
default is `torch.optim.Adam`.
:param dict optimizer_weights_kwargs: Optimizer constructor
keyword args. for the mask module.
:param torch.optim.LRScheduler scheduler_model: Learning
rate scheduler for the model.
:param dict scheduler_model_kwargs: LR scheduler constructor
keyword args.
:param torch.optim.LRScheduler scheduler_weights: Learning
rate scheduler for the mask model.
:param dict scheduler_model_kwargs: LR scheduler constructor
keyword args.
:param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module model: The neural network model to use for
the model.
:param torch.nn.Module weight_function: The neural network model
related to the Self-Adaptive PINN mask; default `torch.nn.Sigmoid()`
:param torch.optim.Optimizer optimizer_model: The neural network
optimizer to use for the model network; default `None`.
:param torch.optim.Optimizer optimizer_weights: The neural network
optimizer to use for mask model; default `None`.
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
for the model; default `None`.
:param torch.optim.LRScheduler scheduler_weights: Learning rate
scheduler for the mask model; default `None`.
:param WeightingInterface weighting: The weighting schema to use;
default `None`.
:param torch.nn.Module loss: The loss function to be minimized;
default `None`.
"""
# check consistency of weight_function
check_consistency(weights_function, torch.nn.Module)
check_consistency(weight_function, torch.nn.Module)
# create models for weights
weights_dict = {}
for condition_name in problem.conditions:
weights_dict[condition_name] = Weights(weights_function)
weights_dict[condition_name] = Weights(weight_function)
weights_dict = torch.nn.ModuleDict(weights_dict)
super().__init__(
models=[model, weights_dict],
problem=problem,
optimizers=[optimizer_model, optimizer_weights],
optimizers_kwargs=[
optimizer_model_kwargs,
optimizer_weights_kwargs,
],
extra_features=extra_features,
loss=loss,
)
super().__init__(models=[model, weights_dict],
problem=problem,
optimizers=[optimizer_model, optimizer_weights],
schedulers=[scheduler_model, scheduler_weights],
weighting=weighting,
loss=loss)
# set automatic optimization
# Set automatic optimization to False
self.automatic_optimization = False
# check consistency
check_consistency(scheduler_model, LRScheduler, subclass=True)
check_consistency(scheduler_model_kwargs, dict)
check_consistency(scheduler_weights, LRScheduler, subclass=True)
check_consistency(scheduler_weights_kwargs, dict)
# assign schedulers
self._schedulers = [
scheduler_model(self.optimizers[0], **scheduler_model_kwargs),
scheduler_weights(self.optimizers[1], **scheduler_weights_kwargs),
]
self._model = self.models[0]
self._weights = self.models[1]
self._vectorial_loss = deepcopy(loss)
self._vectorial_loss = deepcopy(self.loss)
self._vectorial_loss.reduction = "none"
def forward(self, x):
@@ -213,7 +163,23 @@ class SAPINN(PINNInterface):
:return: PINN solution.
:rtype: LabelTensor
"""
return self.neural_net(x)
return self.model(x)
def training_step(self, batch):
"""
Solver training step, overridden to perform manual optimization.
:param batch: The batch element in the dataloader.
:type batch: tuple
:return: The sum of the loss functions.
:rtype: LabelTensor
"""
self.optimizer_model.instance.zero_grad()
self.optimizer_weights.instance.zero_grad()
loss = super().training_step(batch)
self.optimizer_model.instance.step()
self.optimizer_weights.instance.step()
return loss
def loss_phys(self, samples, equation):
"""
@@ -227,86 +193,72 @@ class SAPINN(PINNInterface):
samples and equation.
:rtype: torch.Tensor
"""
# train weights
self.optimizer_weights.zero_grad()
weighted_loss, _ = self._loss_phys(samples, equation)
# Train the weights
weighted_loss = self._loss_phys(samples, equation)
loss_value = -weighted_loss.as_subclass(torch.Tensor)
self.manual_backward(loss_value)
self.optimizer_weights.step()
# detaching samples from the computational graph to erase it and setting
# the gradient to true to create a new computational graph.
# Detach samples from the existing computational graph and
# create a new one by setting requires_grad to True.
# Alternatively, set `retain_graph=True`.
samples = samples.detach()
samples.requires_grad = True
samples.requires_grad_()
# train model
self.optimizer_model.zero_grad()
weighted_loss, loss = self._loss_phys(samples, equation)
# Train the model
weighted_loss = self._loss_phys(samples, equation)
loss_value = weighted_loss.as_subclass(torch.Tensor)
self.manual_backward(loss_value)
self.optimizer_model.step()
# store loss without weights
self.store_log(loss_value=float(loss))
return loss_value
def loss_data(self, input_tensor, output_tensor):
def loss_data(self, input_pts, output_pts):
"""
Computes the data loss for the SAPINN solver based on input and
output. It computes the loss between the
network output against the true solution.
:param LabelTensor input_tensor: The input to the neural networks.
:param LabelTensor output_tensor: The true solution to compare the
:param LabelTensor input_pts: The input to the neural networks.
:param LabelTensor output_pts: The true solution to compare the
network solution.
:return: The computed data loss.
:rtype: torch.Tensor
"""
# train weights
self.optimizer_weights.zero_grad()
weighted_loss, _ = self._loss_data(input_tensor, output_tensor)
loss_value = -weighted_loss.as_subclass(torch.Tensor)
residual = self.forward(input_pts) - output_pts
loss = self._vectorial_loss(
torch.zeros_like(residual, requires_grad=True), residual
)
loss_value = self._vect_to_scalar(loss).as_subclass(torch.Tensor)
self.manual_backward(loss_value)
self.optimizer_weights.step()
# detaching samples from the computational graph to erase it and setting
# the gradient to true to create a new computational graph.
# In alternative set `retain_graph=True`.
input_tensor = input_tensor.detach()
input_tensor.requires_grad = True
# train model
self.optimizer_model.zero_grad()
weighted_loss, loss = self._loss_data(input_tensor, output_tensor)
loss_value = weighted_loss.as_subclass(torch.Tensor)
self.manual_backward(loss_value)
self.optimizer_model.step()
# store loss without weights
self.store_log(loss_value=float(loss))
return loss_value
def configure_optimizers(self):
"""
Optimizer configuration for the SAPINN
solver.
Optimizer configuration for the SelfAdaptivePINN solver.
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
# if the problem is an InverseProblem, add the unknown parameters
# to the parameters that the optimizer needs to optimize
# If the problem is an InverseProblem, add the unknown parameters
# to the parameters to be optimized
self.optimizer_model.hook(self.model.parameters())
self.optimizer_weights.hook(self.weights_dict.parameters())
if isinstance(self.problem, InverseProblem):
self.optimizers[0].add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
return self.optimizers, self._schedulers
self.optimizer_model.instance.add_param_group(
{
"params": [
self._params[var]
for var in self.problem.unknown_variables
]
}
)
self.scheduler_model.hook(self.optimizer_model)
self.scheduler_weights.hook(self.optimizer_weights)
return (
[self.optimizer_model.instance,
self.optimizer_weights.instance],
[self.scheduler_model.instance,
self.scheduler_weights.instance]
)
def on_train_batch_end(self, outputs, batch, batch_idx):
"""
@@ -322,9 +274,11 @@ class SAPINN(PINNInterface):
:rtype: Any
"""
# increase by one the counter of optimization to save loggers
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += (
1
)
(
self.trainer.fit_loop.epoch_loop.manual_optimization
.optim_step_progress.total.completed
) += 1
return super().on_train_batch_end(outputs, batch, batch_idx)
def on_train_start(self):
@@ -336,32 +290,45 @@ class SAPINN(PINNInterface):
method ``on_train_start``.
:rtype: Any
"""
if self.trainer.batch_size is not None:
raise NotImplementedError("SelfAdaptivePINN only works with full "
"batch size, set batch_size=None inside "
"the Trainer to use the solver.")
device = torch.device(
self.trainer._accelerator_connector._accelerator_flag
)
for condition_name, tensor in self.problem.input_pts.items():
self.weights_dict.torchmodel[condition_name].sa_weights.data = (
# Initialize the self adaptive weights only for training points
for condition_name, tensor in (
self.trainer.data_module.train_dataset.input_points.items()
):
self.weights_dict[condition_name].sa_weights.data = (
torch.rand((tensor.shape[0], 1), device=device)
)
return super().on_train_start()
def on_load_checkpoint(self, checkpoint):
"""
Overriding the Pytorch Lightning ``on_load_checkpoint`` to handle
checkpoints for Self Adaptive Weights. This method should not be
Override the Pytorch Lightning ``on_load_checkpoint`` to handle
checkpoints for Self-Adaptive Weights. This method should not be
overridden if not intentionally.
:param dict checkpoint: Pytorch Lightning checkpoint dict.
"""
for condition_name, tensor in self.problem.input_pts.items():
self.weights_dict.torchmodel[condition_name].sa_weights.data = (
torch.rand((tensor.shape[0], 1))
# First initialize self-adaptive weights with correct shape,
# then load the values from the checkpoint.
for condition_name, _ in self.problem.input_pts.items():
shape = checkpoint['state_dict'][
f"_pina_models.1.{condition_name}.sa_weights"
].shape
self.weights_dict[condition_name].sa_weights.data = (
torch.rand(shape)
)
return super().on_load_checkpoint(checkpoint)
def _loss_phys(self, samples, equation):
"""
Elaboration of the physical loss for the SAPINN solver.
Computation of the physical loss for the SelfAdaptivePINN solver.
:param LabelTensor samples: Input samples to evaluate the physics loss.
:param EquationInterface equation: the governing equation representing
@@ -371,43 +338,11 @@ class SAPINN(PINNInterface):
:rtype: List[LabelTensor, LabelTensor]
"""
residual = self.compute_residual(samples, equation)
return self._compute_loss(residual)
def _loss_data(self, input_tensor, output_tensor):
"""
Elaboration of the loss related to data for the SAPINN solver.
:param LabelTensor input_tensor: The input to the neural networks.
:param LabelTensor output_tensor: The true solution to compare the
network solution.
:return: tuple with weighted and not weighted scalar loss
:rtype: List[LabelTensor, LabelTensor]
"""
residual = self.forward(input_tensor) - output_tensor
return self._compute_loss(residual)
def _compute_loss(self, residual):
"""
Elaboration of the pointwise loss through the mask model and the
self adaptive weights
:param LabelTensor residual: the matrix of residuals that have to
be weighted
:return: tuple with weighted and not weighted loss
:rtype List[LabelTensor, LabelTensor]
"""
weights = self.weights_dict.torchmodel[
self.current_condition_name
].forward()
weights = self.weights_dict[self.current_condition_name].forward()
loss_value = self._vectorial_loss(
torch.zeros_like(residual, requires_grad=True), residual
)
return (
self._vect_to_scalar(weights * loss_value),
self._vect_to_scalar(loss_value),
)
return self._vect_to_scalar(weights * loss_value)
def _vect_to_scalar(self, loss_value):
"""
@@ -431,12 +366,14 @@ class SAPINN(PINNInterface):
return ret
@property
def neural_net(self):
def model(self):
"""
Returns the neural network model.
Return the mask models associated with the application of
the mask to the self-adaptive weights for each loss that
composes the global loss of the problem.
:return: The neural network model.
:rtype: torch.nn.Module
:return: The ModuleDict for mask models.
:rtype: torch.nn.ModuleDict
"""
return self.models[0]
@@ -460,7 +397,7 @@ class SAPINN(PINNInterface):
:return: The scheduler for the neural network model.
:rtype: torch.optim.lr_scheduler._LRScheduler
"""
return self._scheduler[0]
return self.schedulers[0]
@property
def scheduler_weights(self):
@@ -470,7 +407,7 @@ class SAPINN(PINNInterface):
:return: The scheduler for the mask model.
:rtype: torch.optim.lr_scheduler._LRScheduler
"""
return self._scheduler[1]
return self.schedulers[1]
@property
def optimizer_model(self):