Update PINNInterface Inheritance (#542)

Dario Coscia
2025-04-14 09:45:01 +02:00
committed by FilippoOlivo
parent 88b27605f1
commit e1863d9318
2 changed files with 46 additions and 51 deletions

View File

@@ -2,12 +2,8 @@
 from abc import ABCMeta, abstractmethod
 import torch
-from torch.nn.modules.loss import _Loss
-from ..solver import SolverInterface
-from ...utils import check_consistency
-from ...loss.loss_interface import LossInterface
-from ...problem import InverseProblem
+from ..supervised_solver import SupervisedSolverInterface
 from ...condition import (
     InputTargetCondition,
     InputEquationCondition,
@@ -15,7 +11,7 @@ from ...condition import (
 )
 
 
-class PINNInterface(SolverInterface, metaclass=ABCMeta):
+class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
     """
     Base class for Physics-Informed Neural Network (PINN) solvers, implementing
     the :class:`~pina.solver.solver.SolverInterface` class.
@@ -32,7 +28,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         DomainEquationCondition,
     )
 
-    def __init__(self, problem, loss=None, **kwargs):
+    def __init__(self, **kwargs):
        """
         Initialization of the :class:`PINNInterface` class.
@@ -41,28 +37,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
             If ``None``, the :class:`torch.nn.MSELoss` loss is used.
             Default is ``None``.
         :param kwargs: Additional keyword arguments to be passed to the
-            :class:`~pina.solver.solver.SolverInterface` class.
+            :class:`~pina.solver.supervised_solver.SupervisedSolverInterface`
+            class.
         """
-        if loss is None:
-            loss = torch.nn.MSELoss()
-        super().__init__(problem=problem, use_lt=True, **kwargs)
-        # check consistency
-        check_consistency(loss, (LossInterface, _Loss), subclass=False)
-        # assign variables
-        self._loss_fn = loss
-        # inverse problem handling
-        if isinstance(self.problem, InverseProblem):
-            self._params = self.problem.unknown_parameters
-            self._clamp_params = self._clamp_inverse_problem_params
-        else:
-            self._params = None
-            self._clamp_params = lambda: None
+        kwargs["use_lt"] = True
+        super().__init__(**kwargs)
         # current condition name
         self.__metric = None
 
     def optimization_cycle(self, batch, loss_residuals=None):
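
With this hunk, PINNInterface no longer accepts problem and loss itself; both travel through **kwargs to SupervisedSolverInterface, and use_lt is forced on. A minimal sketch of that forwarding pattern, in plain Python with illustrative names rather than the actual PINA classes:

    # The subclass forces one flag and lets the parent own every other
    # argument, so the parent's validation logic is not duplicated.
    class Parent:
        def __init__(self, problem, loss=None, use_lt=False):
            self.problem = problem
            self.loss = loss if loss is not None else "MSELoss"
            self.use_lt = use_lt

    class Child(Parent):
        def __init__(self, **kwargs):
            kwargs["use_lt"] = True  # forced, as in this commit
            super().__init__(**kwargs)

    c = Child(problem="poisson")
    assert c.use_lt and c.loss == "MSELoss"
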
@@ -103,8 +84,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
             )
             # append loss
             condition_loss[condition_name] = loss
-        # clamp unknown parameters in InverseProblem (if needed)
-        self._clamp_params()
         return condition_loss
 
     @torch.set_grad_enabled(True)
@@ -135,7 +114,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         """
         return super().test_step(batch, loss_residuals=self._residual_loss)
 
-    @abstractmethod
     def loss_data(self, input, target):
         """
         Compute the data loss for the PINN solver by evaluating the loss
@@ -147,7 +125,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
             network's output.
         :return: The supervised loss, averaged over the number of observations.
         :rtype: LabelTensor
+        :raises NotImplementedError: If the method is not implemented.
         """
+        raise NotImplementedError(
+            "PINN is being used in a supervised learning context, but the "
+            "'loss_data' method has not been implemented. "
+        )
 
     @abstractmethod
     def loss_phys(self, samples, equation):
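
Replacing @abstractmethod with a concrete body that raises NotImplementedError moves the failure from class-instantiation time to call time: physics-only solvers can subclass PINNInterface without stubbing loss_data, while supervised use still fails loudly. An illustrative sketch, in plain Python with hypothetical class names:

    class Base:
        def loss_data(self, input, target):
            raise NotImplementedError("'loss_data' has not been implemented.")

    class PhysicsOnly(Base):
        pass  # legal now: no data conditions, no override needed

    class WithData(Base):
        def loss_data(self, input, target):
            # a toy mean-squared error
            return sum((i - t) ** 2 for i, t in zip(input, target)) / len(input)

    PhysicsOnly()  # instantiates; would raise TypeError were loss_data abstract
    print(WithData().loss_data([1.0, 2.0], [1.0, 1.0]))  # 0.5
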
@@ -196,26 +179,6 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         residuals = self.compute_residual(samples, equation)
         return self._loss_fn(residuals, torch.zeros_like(residuals))
 
-    def _clamp_inverse_problem_params(self):
-        """
-        Clamps the parameters of the inverse problem solver to specified ranges.
-        """
-        for v in self._params:
-            self._params[v].data.clamp_(
-                self.problem.unknown_parameter_domain.range_[v][0],
-                self.problem.unknown_parameter_domain.range_[v][1],
-            )
-
-    @property
-    def loss(self):
-        """
-        The loss used for training.
-
-        :return: The loss function used for training.
-        :rtype: torch.nn.Module
-        """
-        return self._loss_fn
-
     @property
     def current_condition_name(self):
         """

View File

@@ -5,7 +5,7 @@ import lightning
 import torch
 from torch._dynamo import OptimizedModule
 
-from ..problem import AbstractProblem
+from ..problem import AbstractProblem, InverseProblem
 from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
 from ..loss import WeightingInterface
 from ..loss.scalar_weighting import _NoWeighting
@@ -64,6 +64,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         self._pina_optimizers = None
         self._pina_schedulers = None
 
+        # inverse problem handling
+        if isinstance(self.problem, InverseProblem):
+            self._params = self.problem.unknown_parameters
+            self._clamp_params = self._clamp_inverse_problem_params
+        else:
+            self._params = None
+            self._clamp_params = lambda: None
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         """
@@ -231,14 +239,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
             containing the condition name and the associated scalar loss.
         :rtype: dict
         """
+        # compute losses
         losses = self.optimization_cycle(batch)
+        # clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+        # store log
         for name, value in losses.items():
             self.store_log(
                 f"{name}_loss", value.item(), self.get_batch_size(batch)
             )
+        # aggregate
         loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
         return loss
 
+    def _clamp_inverse_problem_params(self):
+        """
+        Clamps the parameters of the inverse problem solver to specified ranges.
+        """
+        for v in self._params:
+            self._params[v].data.clamp_(
+                self.problem.unknown_parameter_domain.range_[v][0],
+                self.problem.unknown_parameter_domain.range_[v][1],
+            )
+
     @staticmethod
     def _compile_modules(model):
         """
@@ -405,6 +428,15 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
         :rtype: tuple[list[Optimizer], list[Scheduler]]
         """
         self.optimizer.hook(self.model.parameters())
+        if isinstance(self.problem, InverseProblem):
+            self.optimizer.instance.add_param_group(
+                {
+                    "params": [
+                        self._params[var]
+                        for var in self.problem.unknown_variables
+                    ]
+                }
+            )
         self.scheduler.hook(self.optimizer)
         return ([self.optimizer.instance], [self.scheduler.instance])
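
The new branch registers the inverse problem's unknown parameters as an extra optimizer param group, so they are updated alongside the model weights. A sketch of the same mechanism with a plain torch optimizer and illustrative names:

    import torch

    model = torch.nn.Linear(2, 1)
    unknowns = {"mu": torch.nn.Parameter(torch.tensor(0.5))}
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # unknown parameters join as a second param group
    optimizer.add_param_group({"params": [unknowns[v] for v in unknowns]})
    print(len(optimizer.param_groups))  # 2: model weights + unknowns
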