Files
PINA/pina/solver/physic_informed_solver/pinn_interface.py
Filippo Olivo 4177bfbb50 Fix Codacy Warnings (#477)
---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
2025-03-19 17:48:18 +01:00

183 lines
6.4 KiB
Python

"""Module for Physics Informed Neural Network Interface."""
from abc import ABCMeta, abstractmethod
import torch
from torch.nn.modules.loss import _Loss
from ..solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...condition import (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
class PINNInterface(SolverInterface, metaclass=ABCMeta):
    """
    Base PINN solver class. This class implements the Solver Interface
    for Physics Informed Neural Network solvers.

    This class can be used to define PINNs with multiple ``optimizers``,
    and/or ``models``.
    By default it takes :class:`~pina.problem.abstract_problem.AbstractProblem`,
    so the user can choose what type of problem the implemented solver,
    inheriting from this class, is designed to solve.
    """

    accepted_conditions_types = (
        InputTargetCondition,
        InputEquationCondition,
        DomainEquationCondition,
    )

    def __init__(self, problem, loss=None, **kwargs):
        """
        :param AbstractProblem problem: A problem definition instance.
        :param torch.nn.Module loss: The loss function to be minimized. It
            must be an instance of
            :class:`~pina.loss.loss_interface.LossInterface` or of
            :class:`torch.nn.modules.loss._Loss`. If ``None`` (default),
            :class:`torch.nn.MSELoss` is used.
        :param kwargs: Additional keyword arguments forwarded to the parent
            solver constructor.
        """
        if loss is None:
            loss = torch.nn.MSELoss()
        super().__init__(problem=problem, use_lt=True, **kwargs)

        # check consistency
        check_consistency(loss, (LossInterface, _Loss), subclass=False)

        # assign variables
        self._loss = loss

        # inverse problem handling: keep a handle on the unknown parameters
        # so they can be passed to the equations (see compute_residual) and
        # clamped at the end of every optimization cycle
        if isinstance(self.problem, InverseProblem):
            self._params = self.problem.unknown_parameters
            self._clamp_params = self._clamp_inverse_problem_params
        else:
            self._params = None
            # no-op, so the optimization cycle can call it unconditionally
            self._clamp_params = lambda: None

        # name of the condition currently being processed
        self.__metric = None

    def optimization_cycle(self, batch):
        """
        Compute the training loss of every condition in ``batch``.

        Equation conditions are evaluated with :meth:`loss_phys`, data
        conditions with :meth:`loss_data`.

        :param batch: Iterable of ``(condition_name, points)`` pairs, where
            ``points`` holds an ``"input"`` entry and, for data conditions,
            a ``"target"`` entry.
        :return: Mapping from condition name to its loss.
        :rtype: dict
        """
        return self._run_optimization_cycle(batch, self.loss_phys)

    # Residual evaluation marks the inputs with ``requires_grad_``,
    # presumably so equations can differentiate the network output with
    # respect to them; gradients are therefore kept enabled even when the
    # trainer would disable them for evaluation.
    @torch.set_grad_enabled(True)
    def validation_step(self, batch):
        """
        Compute, log (as ``"val_loss"``) and return the aggregated residual
        loss of a validation batch.

        :param batch: Iterable of ``(condition_name, points)`` pairs.
        :return: The aggregated validation loss.
        :rtype: torch.Tensor
        """
        return self._evaluate_residual_step(batch, "val_loss")

    @torch.set_grad_enabled(True)
    def test_step(self, batch):
        """
        Compute, log (as ``"test_loss"``) and return the aggregated residual
        loss of a test batch.

        :param batch: Iterable of ``(condition_name, points)`` pairs.
        :return: The aggregated test loss.
        :rtype: torch.Tensor
        """
        return self._evaluate_residual_step(batch, "test_loss")

    def _evaluate_residual_step(self, batch, log_name):
        """
        Shared body of :meth:`validation_step` and :meth:`test_step`: compute
        the per-condition residual losses, aggregate them with the solver
        weighting, log the result under ``log_name`` and return it.

        :param batch: Iterable of ``(condition_name, points)`` pairs.
        :param str log_name: Name under which the loss is logged.
        :return: The aggregated loss as a plain :class:`torch.Tensor`.
        :rtype: torch.Tensor
        """
        losses = self._run_optimization_cycle(batch, self._residual_loss)
        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
        self.store_log(log_name, loss, self.get_batch_size(batch))
        return loss

    def loss_data(self, input_pts, output_pts):
        """
        The data loss for the PINN solver. It computes the loss between the
        network output and the true solution. This method should not be
        overridden unless intentionally.

        :param LabelTensor input_pts: The input to the neural networks.
        :param LabelTensor output_pts: The true solution to compare with the
            network solution.
        :return: The loss between the network output at ``input_pts`` and
            ``output_pts``, computed with the solver loss function.
        :rtype: torch.Tensor
        """
        return self._loss(self.forward(input_pts), output_pts)

    @abstractmethod
    def loss_phys(self, samples, equation):
        """
        Computes the physics loss for the physics informed solver based on
        the given samples and equation. This method must be overridden by all
        inherited classes, and it is the core of a new physics informed
        solver.

        :param LabelTensor samples: The samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation
            representing the physics.
        :return: The physics loss calculated based on the given
            samples and equation.
        :rtype: LabelTensor
        """

    def compute_residual(self, samples, equation):
        """
        Compute the residual for Physics Informed learning. This function
        returns the residual of the :obj:`~pina.equation.equation.Equation`
        specified in the :obj:`~pina.condition.Condition`, evaluated at the
        ``samples`` points with the network output at those points.

        The equation is first called as ``residual(samples, output)``; if
        that raises :class:`TypeError` (an equation that also expects the
        unknown parameters of an inverse problem), it is called again as
        ``residual(samples, output, params)``.

        :param LabelTensor samples: The samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation
            representing the physics.
        :return: The residual of the neural network solution.
        :rtype: LabelTensor
        """
        # Evaluate the network once, outside the ``try``: a TypeError raised
        # by the forward pass must not be mistaken for a signature mismatch,
        # and the fallback call should not repeat the forward pass.
        output = self.forward(samples)
        try:
            residual = equation.residual(samples, output)
        except TypeError:
            # this occurs when the function has three inputs (inverse problem)
            residual = equation.residual(samples, output, self._params)
        return residual

    def _residual_loss(self, samples, equation):
        """
        Loss between the equation residual at ``samples`` and zero, used as
        the equation-condition loss during validation and testing.

        :param LabelTensor samples: The points where the residual is
            evaluated.
        :param EquationInterface equation: The governing equation.
        :return: The residual loss.
        :rtype: torch.Tensor
        """
        residuals = self.compute_residual(samples, equation)
        return self.loss(residuals, torch.zeros_like(residuals))

    def _run_optimization_cycle(self, batch, loss_residuals):
        """
        Compute the loss of every condition in ``batch``.

        Conditions without a ``"target"`` entry are treated as equation
        conditions and evaluated with ``loss_residuals``. The others are
        treated as data conditions and evaluated with :meth:`loss_data`.
        After all conditions are processed, the unknown parameters of an
        inverse problem (if any) are clamped to their ranges.

        :param batch: Iterable of ``(condition_name, points)`` pairs.
        :param callable loss_residuals: Function ``(samples, equation)``
            returning the loss of an equation condition.
        :return: Mapping from condition name to its loss.
        :rtype: dict
        """
        condition_loss = {}
        for condition_name, points in batch:
            # exposed through the ``current_condition_name`` property
            self.__metric = condition_name
            input_pts = points["input"]
            # if equations are passed
            if "target" not in points:
                condition = self.problem.conditions[condition_name]
                loss = loss_residuals(
                    input_pts.requires_grad_(), condition.equation
                )
            # if data are passed
            else:
                loss = self.loss_data(
                    input_pts=input_pts.requires_grad_(),
                    output_pts=points["target"],
                )
            condition_loss[condition_name] = loss
        # clamp unknown parameters in InverseProblem (if needed)
        self._clamp_params()
        return condition_loss

    def _clamp_inverse_problem_params(self):
        """
        Clamps the parameters of the inverse problem, in place, to the
        ranges given by ``problem.unknown_parameter_domain.range_``.
        """
        param_ranges = self.problem.unknown_parameter_domain.range_
        for v in self._params:
            # modify ``.data`` so the clamp is not recorded by autograd
            self._params[v].data.clamp_(param_ranges[v][0], param_ranges[v][1])

    @property
    def loss(self):
        """
        Loss used for training.
        """
        return self._loss

    @property
    def current_condition_name(self):
        """
        The name of the condition currently (or most recently) processed in
        an optimization cycle, or ``None`` before the first cycle.
        """
        return self.__metric