Update PINNInterface Inheritance (#542)
This commit is contained in:
@@ -2,12 +2,8 @@
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from ..solver import SolverInterface
|
||||
from ...utils import check_consistency
|
||||
from ...loss.loss_interface import LossInterface
|
||||
from ...problem import InverseProblem
|
||||
from ..supervised_solver import SupervisedSolverInterface
|
||||
from ...condition import (
|
||||
InputTargetCondition,
|
||||
InputEquationCondition,
|
||||
@@ -15,7 +11,7 @@ from ...condition import (
|
||||
)
|
||||
|
||||
|
||||
class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
Base class for Physics-Informed Neural Network (PINN) solvers, implementing
|
||||
the :class:`~pina.solver.solver.SolverInterface` class.
|
||||
@@ -32,7 +28,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
def __init__(self, problem, loss=None, **kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initialization of the :class:`PINNInterface` class.
|
||||
|
||||
@@ -41,28 +37,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
:param kwargs: Additional keyword arguments to be passed to the
|
||||
:class:`~pina.solver.solver.SolverInterface` class.
|
||||
:class:`~pina.solver.supervised_solver.SupervisedSolverInterface`
|
||||
class.
|
||||
"""
|
||||
kwargs["use_lt"] = True
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if loss is None:
|
||||
loss = torch.nn.MSELoss()
|
||||
|
||||
super().__init__(problem=problem, use_lt=True, **kwargs)
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
|
||||
# assign variables
|
||||
self._loss_fn = loss
|
||||
|
||||
# inverse problem handling
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self._params = self.problem.unknown_parameters
|
||||
self._clamp_params = self._clamp_inverse_problem_params
|
||||
else:
|
||||
self._params = None
|
||||
self._clamp_params = lambda: None
|
||||
|
||||
# current condition name
|
||||
self.__metric = None
|
||||
|
||||
def optimization_cycle(self, batch, loss_residuals=None):
|
||||
@@ -103,8 +84,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
)
|
||||
# append loss
|
||||
condition_loss[condition_name] = loss
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
self._clamp_params()
|
||||
return condition_loss
|
||||
|
||||
@torch.set_grad_enabled(True)
|
||||
@@ -135,7 +114,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
"""
|
||||
return super().test_step(batch, loss_residuals=self._residual_loss)
|
||||
|
||||
@abstractmethod
|
||||
def loss_data(self, input, target):
|
||||
"""
|
||||
Compute the data loss for the PINN solver by evaluating the loss
|
||||
@@ -147,7 +125,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
network's output.
|
||||
:return: The supervised loss, averaged over the number of observations.
|
||||
:rtype: LabelTensor
|
||||
:raises NotImplementedError: If the method is not implemented.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"PINN is being used in a supervised learning context, but the "
|
||||
"'loss_data' method has not been implemented. "
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def loss_phys(self, samples, equation):
|
||||
@@ -196,26 +179,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
residuals = self.compute_residual(samples, equation)
|
||||
return self._loss_fn(residuals, torch.zeros_like(residuals))
|
||||
|
||||
def _clamp_inverse_problem_params(self):
|
||||
"""
|
||||
Clamps the parameters of the inverse problem solver to specified ranges.
|
||||
"""
|
||||
for v in self._params:
|
||||
self._params[v].data.clamp_(
|
||||
self.problem.unknown_parameter_domain.range_[v][0],
|
||||
self.problem.unknown_parameter_domain.range_[v][1],
|
||||
)
|
||||
|
||||
@property
|
||||
def loss(self):
|
||||
"""
|
||||
The loss used for training.
|
||||
|
||||
:return: The loss function used for training.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._loss_fn
|
||||
|
||||
@property
|
||||
def current_condition_name(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user