Update PINNInterface Inheritance (#542)
This commit is contained in:
committed by
FilippoOlivo
parent
88b27605f1
commit
e1863d9318
@@ -2,12 +2,8 @@
|
|||||||
|
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.modules.loss import _Loss
|
|
||||||
|
|
||||||
from ..solver import SolverInterface
|
from ..supervised_solver import SupervisedSolverInterface
|
||||||
from ...utils import check_consistency
|
|
||||||
from ...loss.loss_interface import LossInterface
|
|
||||||
from ...problem import InverseProblem
|
|
||||||
from ...condition import (
|
from ...condition import (
|
||||||
InputTargetCondition,
|
InputTargetCondition,
|
||||||
InputEquationCondition,
|
InputEquationCondition,
|
||||||
@@ -15,7 +11,7 @@ from ...condition import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
Base class for Physics-Informed Neural Network (PINN) solvers, implementing
|
Base class for Physics-Informed Neural Network (PINN) solvers, implementing
|
||||||
the :class:`~pina.solver.solver.SolverInterface` class.
|
the :class:`~pina.solver.solver.SolverInterface` class.
|
||||||
@@ -32,7 +28,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
DomainEquationCondition,
|
DomainEquationCondition,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, problem, loss=None, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Initialization of the :class:`PINNInterface` class.
|
Initialization of the :class:`PINNInterface` class.
|
||||||
|
|
||||||
@@ -41,28 +37,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||||
Default is `None`.
|
Default is `None`.
|
||||||
:param kwargs: Additional keyword arguments to be passed to the
|
:param kwargs: Additional keyword arguments to be passed to the
|
||||||
:class:`~pina.solver.solver.SolverInterface` class.
|
:class:`~pina.solver.supervised_solver.SupervisedSolverInterface`
|
||||||
|
class.
|
||||||
"""
|
"""
|
||||||
|
kwargs["use_lt"] = True
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
if loss is None:
|
# current condition name
|
||||||
loss = torch.nn.MSELoss()
|
|
||||||
|
|
||||||
super().__init__(problem=problem, use_lt=True, **kwargs)
|
|
||||||
|
|
||||||
# check consistency
|
|
||||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
|
||||||
|
|
||||||
# assign variables
|
|
||||||
self._loss_fn = loss
|
|
||||||
|
|
||||||
# inverse problem handling
|
|
||||||
if isinstance(self.problem, InverseProblem):
|
|
||||||
self._params = self.problem.unknown_parameters
|
|
||||||
self._clamp_params = self._clamp_inverse_problem_params
|
|
||||||
else:
|
|
||||||
self._params = None
|
|
||||||
self._clamp_params = lambda: None
|
|
||||||
|
|
||||||
self.__metric = None
|
self.__metric = None
|
||||||
|
|
||||||
def optimization_cycle(self, batch, loss_residuals=None):
|
def optimization_cycle(self, batch, loss_residuals=None):
|
||||||
@@ -103,8 +84,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
)
|
)
|
||||||
# append loss
|
# append loss
|
||||||
condition_loss[condition_name] = loss
|
condition_loss[condition_name] = loss
|
||||||
# clamp unknown parameters in InverseProblem (if needed)
|
|
||||||
self._clamp_params()
|
|
||||||
return condition_loss
|
return condition_loss
|
||||||
|
|
||||||
@torch.set_grad_enabled(True)
|
@torch.set_grad_enabled(True)
|
||||||
@@ -135,7 +114,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
return super().test_step(batch, loss_residuals=self._residual_loss)
|
return super().test_step(batch, loss_residuals=self._residual_loss)
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def loss_data(self, input, target):
|
def loss_data(self, input, target):
|
||||||
"""
|
"""
|
||||||
Compute the data loss for the PINN solver by evaluating the loss
|
Compute the data loss for the PINN solver by evaluating the loss
|
||||||
@@ -147,7 +125,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
network's output.
|
network's output.
|
||||||
:return: The supervised loss, averaged over the number of observations.
|
:return: The supervised loss, averaged over the number of observations.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
|
:raises NotImplementedError: If the method is not implemented.
|
||||||
"""
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"PINN is being used in a supervised learning context, but the "
|
||||||
|
"'loss_data' method has not been implemented. "
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def loss_phys(self, samples, equation):
|
def loss_phys(self, samples, equation):
|
||||||
@@ -196,26 +179,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
residuals = self.compute_residual(samples, equation)
|
residuals = self.compute_residual(samples, equation)
|
||||||
return self._loss_fn(residuals, torch.zeros_like(residuals))
|
return self._loss_fn(residuals, torch.zeros_like(residuals))
|
||||||
|
|
||||||
def _clamp_inverse_problem_params(self):
|
|
||||||
"""
|
|
||||||
Clamps the parameters of the inverse problem solver to specified ranges.
|
|
||||||
"""
|
|
||||||
for v in self._params:
|
|
||||||
self._params[v].data.clamp_(
|
|
||||||
self.problem.unknown_parameter_domain.range_[v][0],
|
|
||||||
self.problem.unknown_parameter_domain.range_[v][1],
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def loss(self):
|
|
||||||
"""
|
|
||||||
The loss used for training.
|
|
||||||
|
|
||||||
:return: The loss function used for training.
|
|
||||||
:rtype: torch.nn.Module
|
|
||||||
"""
|
|
||||||
return self._loss_fn
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_condition_name(self):
|
def current_condition_name(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import lightning
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch._dynamo import OptimizedModule
|
from torch._dynamo import OptimizedModule
|
||||||
from ..problem import AbstractProblem
|
from ..problem import AbstractProblem, InverseProblem
|
||||||
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
|
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
|
||||||
from ..loss import WeightingInterface
|
from ..loss import WeightingInterface
|
||||||
from ..loss.scalar_weighting import _NoWeighting
|
from ..loss.scalar_weighting import _NoWeighting
|
||||||
@@ -64,6 +64,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
self._pina_optimizers = None
|
self._pina_optimizers = None
|
||||||
self._pina_schedulers = None
|
self._pina_schedulers = None
|
||||||
|
|
||||||
|
# inverse problem handling
|
||||||
|
if isinstance(self.problem, InverseProblem):
|
||||||
|
self._params = self.problem.unknown_parameters
|
||||||
|
self._clamp_params = self._clamp_inverse_problem_params
|
||||||
|
else:
|
||||||
|
self._params = None
|
||||||
|
self._clamp_params = lambda: None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -231,14 +239,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
containing the condition name and the associated scalar loss.
|
containing the condition name and the associated scalar loss.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
# compute losses
|
||||||
losses = self.optimization_cycle(batch)
|
losses = self.optimization_cycle(batch)
|
||||||
|
# clamp unknown parameters in InverseProblem (if needed)
|
||||||
|
self._clamp_params()
|
||||||
|
# store log
|
||||||
for name, value in losses.items():
|
for name, value in losses.items():
|
||||||
self.store_log(
|
self.store_log(
|
||||||
f"{name}_loss", value.item(), self.get_batch_size(batch)
|
f"{name}_loss", value.item(), self.get_batch_size(batch)
|
||||||
)
|
)
|
||||||
|
# aggregate
|
||||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def _clamp_inverse_problem_params(self):
|
||||||
|
"""
|
||||||
|
Clamps the parameters of the inverse problem solver to specified ranges.
|
||||||
|
"""
|
||||||
|
for v in self._params:
|
||||||
|
self._params[v].data.clamp_(
|
||||||
|
self.problem.unknown_parameter_domain.range_[v][0],
|
||||||
|
self.problem.unknown_parameter_domain.range_[v][1],
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _compile_modules(model):
|
def _compile_modules(model):
|
||||||
"""
|
"""
|
||||||
@@ -405,6 +428,15 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
:rtype: tuple[list[Optimizer], list[Scheduler]]
|
:rtype: tuple[list[Optimizer], list[Scheduler]]
|
||||||
"""
|
"""
|
||||||
self.optimizer.hook(self.model.parameters())
|
self.optimizer.hook(self.model.parameters())
|
||||||
|
if isinstance(self.problem, InverseProblem):
|
||||||
|
self.optimizer.instance.add_param_group(
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
self._params[var]
|
||||||
|
for var in self.problem.unknown_variables
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
self.scheduler.hook(self.optimizer)
|
self.scheduler.hook(self.optimizer)
|
||||||
return ([self.optimizer.instance], [self.scheduler.instance])
|
return ([self.optimizer.instance], [self.scheduler.instance])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user