Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -1,4 +1,4 @@
""" Module for Physics Informed Neural Network Interface."""
"""Module for Physics Informed Neural Network Interface."""
from abc import ABCMeta, abstractmethod
import torch
@@ -11,7 +11,7 @@ from ...problem import InverseProblem
from ...condition import (
InputOutputPointsCondition,
InputPointsEquationCondition,
DomainEquationCondition
DomainEquationCondition,
)
@@ -20,22 +20,20 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
Base PINN solver class. This class implements the Solver Interface
for Physics Informed Neural Network solver.
This class can be used to define PINNs with multiple ``optimizers``,
This class can be used to define PINNs with multiple ``optimizers``,
and/or ``models``.
By default it takes :class:`~pina.problem.abstract_problem.AbstractProblem`,
so the user can choose what type of problem the implemented solver,
inheriting from this class, is designed to solve.
"""
accepted_conditions_types = (
InputOutputPointsCondition,
InputPointsEquationCondition,
DomainEquationCondition
DomainEquationCondition,
)
def __init__(self,
problem,
loss=None,
**kwargs):
def __init__(self, problem, loss=None, **kwargs):
"""
:param AbstractProblem problem: A problem definition instance.
:param torch.nn.Module loss: The loss function to be minimized,
@@ -45,9 +43,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
if loss is None:
loss = torch.nn.MSELoss()
super().__init__(problem=problem,
use_lt=True,
**kwargs)
super().__init__(problem=problem, use_lt=True, **kwargs)
# check consistency
check_consistency(loss, (LossInterface, _Loss), subclass=False)
@@ -72,14 +68,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def validation_step(self, batch):
losses = self._run_optimization_cycle(batch, self._residual_loss)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
self.store_log('val_loss', loss, self.get_batch_size(batch))
self.store_log("val_loss", loss, self.get_batch_size(batch))
return loss
@torch.set_grad_enabled(True)
def test_step(self, batch):
losses = self._run_optimization_cycle(batch, self._residual_loss)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
self.store_log('test_loss', loss, self.get_batch_size(batch))
self.store_log("test_loss", loss, self.get_batch_size(batch))
return loss
def loss_data(self, input_pts, output_pts):
@@ -129,42 +125,38 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
except TypeError:
# this occurs when the function has three inputs (inverse problem)
residual = equation.residual(
samples,
self.forward(samples),
self._params
samples, self.forward(samples), self._params
)
return residual
def _residual_loss(self, samples, equation):
residuals = self.compute_residual(samples, equation)
return self.loss(residuals, torch.zeros_like(residuals))
def _run_optimization_cycle(self, batch, loss_residuals):
condition_loss = {}
for condition_name, points in batch:
self.__metric = condition_name
# if equations are passed
if 'output_points' not in points:
input_pts = points['input_points']
if "output_points" not in points:
input_pts = points["input_points"]
condition = self.problem.conditions[condition_name]
loss = loss_residuals(
input_pts.requires_grad_(),
condition.equation
input_pts.requires_grad_(), condition.equation
)
# if data are passed
else:
input_pts = points['input_points']
output_pts = points['output_points']
input_pts = points["input_points"]
output_pts = points["output_points"]
loss = self.loss_data(
input_pts=input_pts.requires_grad_(),
output_pts=output_pts
input_pts=input_pts.requires_grad_(), output_pts=output_pts
)
# append loss
condition_loss[condition_name] = loss
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()
return condition_loss
def _clamp_inverse_problem_params(self):
"""
Clamps the parameters of the inverse problem
@@ -175,14 +167,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
self.problem.unknown_parameter_domain.range_[v][0],
self.problem.unknown_parameter_domain.range_[v][1],
)
@property
def loss(self):
"""
Loss used for training.
"""
return self._loss
@property
def current_condition_name(self):
"""