Update PINNInterface Inheritance (#542)

This commit is contained in:
Dario Coscia
2025-04-14 09:45:01 +02:00
parent 3679da7bac
commit 7e49392ac3
2 changed files with 46 additions and 51 deletions

View File

@@ -2,12 +2,8 @@
from abc import ABCMeta, abstractmethod
import torch
from torch.nn.modules.loss import _Loss
from ..solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ..supervised_solver import SupervisedSolverInterface
from ...condition import (
InputTargetCondition,
InputEquationCondition,
@@ -15,7 +11,7 @@ from ...condition import (
)
class PINNInterface(SolverInterface, metaclass=ABCMeta):
class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
"""
Base class for Physics-Informed Neural Network (PINN) solvers, implementing
the :class:`~pina.solver.solver.SolverInterface` class.
@@ -32,7 +28,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
DomainEquationCondition,
)
def __init__(self, problem, loss=None, **kwargs):
def __init__(self, **kwargs):
"""
Initialization of the :class:`PINNInterface` class.
@@ -41,28 +37,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param kwargs: Additional keyword arguments to be passed to the
:class:`~pina.solver.solver.SolverInterface` class.
:class:`~pina.solver.supervised_solver.SupervisedSolverInterface`
class.
"""
kwargs["use_lt"] = True
super().__init__(**kwargs)
if loss is None:
loss = torch.nn.MSELoss()
super().__init__(problem=problem, use_lt=True, **kwargs)
# check consistency
check_consistency(loss, (LossInterface, _Loss), subclass=False)
# assign variables
self._loss_fn = loss
# inverse problem handling
if isinstance(self.problem, InverseProblem):
self._params = self.problem.unknown_parameters
self._clamp_params = self._clamp_inverse_problem_params
else:
self._params = None
self._clamp_params = lambda: None
# current condition name
self.__metric = None
def optimization_cycle(self, batch, loss_residuals=None):
@@ -103,8 +84,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
)
# append loss
condition_loss[condition_name] = loss
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
return condition_loss
@torch.set_grad_enabled(True)
@@ -135,7 +114,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
"""
return super().test_step(batch, loss_residuals=self._residual_loss)
@abstractmethod
def loss_data(self, input, target):
"""
Compute the data loss for the PINN solver by evaluating the loss
@@ -147,7 +125,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
network's output.
:return: The supervised loss, averaged over the number of observations.
:rtype: LabelTensor
:raises NotImplementedError: If the method is not implemented.
"""
raise NotImplementedError(
"PINN is being used in a supervised learning context, but the "
"'loss_data' method has not been implemented. "
)
@abstractmethod
def loss_phys(self, samples, equation):
@@ -196,26 +179,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
residuals = self.compute_residual(samples, equation)
return self._loss_fn(residuals, torch.zeros_like(residuals))
def _clamp_inverse_problem_params(self):
"""
Clamps the parameters of the inverse problem solver to specified ranges.
"""
for v in self._params:
self._params[v].data.clamp_(
self.problem.unknown_parameter_domain.range_[v][0],
self.problem.unknown_parameter_domain.range_[v][1],
)
@property
def loss(self):
"""
The loss used for training.
:return: The loss function used for training.
:rtype: torch.nn.Module
"""
return self._loss_fn
@property
def current_condition_name(self):
"""