Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -1,4 +1,4 @@
|
||||
""" Module for Physics Informed Neural Network Interface."""
|
||||
"""Module for Physics Informed Neural Network Interface."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
@@ -11,7 +11,7 @@ from ...problem import InverseProblem
|
||||
from ...condition import (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
DomainEquationCondition
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
|
||||
@@ -20,22 +20,20 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
Base PINN solver class. This class implements the Solver Interface
|
||||
for Physics Informed Neural Network solver.
|
||||
|
||||
This class can be used to define PINNs with multiple ``optimizers``,
|
||||
This class can be used to define PINNs with multiple ``optimizers``,
|
||||
and/or ``models``.
|
||||
By default it takes :class:`~pina.problem.abstract_problem.AbstractProblem`,
|
||||
so the user can choose what type of problem the implemented solver,
|
||||
inheriting from this class, is designed to solve.
|
||||
"""
|
||||
|
||||
accepted_conditions_types = (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
DomainEquationCondition
|
||||
DomainEquationCondition,
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
loss=None,
|
||||
**kwargs):
|
||||
def __init__(self, problem, loss=None, **kwargs):
|
||||
"""
|
||||
:param AbstractProblem problem: A problem definition instance.
|
||||
:param torch.nn.Module loss: The loss function to be minimized,
|
||||
@@ -45,9 +43,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
if loss is None:
|
||||
loss = torch.nn.MSELoss()
|
||||
|
||||
super().__init__(problem=problem,
|
||||
use_lt=True,
|
||||
**kwargs)
|
||||
super().__init__(problem=problem, use_lt=True, **kwargs)
|
||||
|
||||
# check consistency
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
@@ -72,14 +68,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
def validation_step(self, batch):
|
||||
losses = self._run_optimization_cycle(batch, self._residual_loss)
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
self.store_log('val_loss', loss, self.get_batch_size(batch))
|
||||
self.store_log("val_loss", loss, self.get_batch_size(batch))
|
||||
return loss
|
||||
|
||||
@torch.set_grad_enabled(True)
|
||||
def test_step(self, batch):
|
||||
losses = self._run_optimization_cycle(batch, self._residual_loss)
|
||||
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
|
||||
self.store_log('test_loss', loss, self.get_batch_size(batch))
|
||||
self.store_log("test_loss", loss, self.get_batch_size(batch))
|
||||
return loss
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
@@ -129,42 +125,38 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
except TypeError:
|
||||
# this occurs when the function has three inputs (inverse problem)
|
||||
residual = equation.residual(
|
||||
samples,
|
||||
self.forward(samples),
|
||||
self._params
|
||||
samples, self.forward(samples), self._params
|
||||
)
|
||||
return residual
|
||||
|
||||
def _residual_loss(self, samples, equation):
|
||||
residuals = self.compute_residual(samples, equation)
|
||||
return self.loss(residuals, torch.zeros_like(residuals))
|
||||
|
||||
|
||||
def _run_optimization_cycle(self, batch, loss_residuals):
|
||||
condition_loss = {}
|
||||
for condition_name, points in batch:
|
||||
self.__metric = condition_name
|
||||
# if equations are passed
|
||||
if 'output_points' not in points:
|
||||
input_pts = points['input_points']
|
||||
if "output_points" not in points:
|
||||
input_pts = points["input_points"]
|
||||
condition = self.problem.conditions[condition_name]
|
||||
loss = loss_residuals(
|
||||
input_pts.requires_grad_(),
|
||||
condition.equation
|
||||
input_pts.requires_grad_(), condition.equation
|
||||
)
|
||||
# if data are passed
|
||||
else:
|
||||
input_pts = points['input_points']
|
||||
output_pts = points['output_points']
|
||||
input_pts = points["input_points"]
|
||||
output_pts = points["output_points"]
|
||||
loss = self.loss_data(
|
||||
input_pts=input_pts.requires_grad_(),
|
||||
output_pts=output_pts
|
||||
input_pts=input_pts.requires_grad_(), output_pts=output_pts
|
||||
)
|
||||
# append loss
|
||||
condition_loss[condition_name] = loss
|
||||
# clamp unknown parameters in InverseProblem (if needed)
|
||||
self._clamp_params()
|
||||
return condition_loss
|
||||
|
||||
|
||||
def _clamp_inverse_problem_params(self):
|
||||
"""
|
||||
Clamps the parameters of the inverse problem
|
||||
@@ -175,14 +167,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
||||
self.problem.unknown_parameter_domain.range_[v][0],
|
||||
self.problem.unknown_parameter_domain.range_[v][1],
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def loss(self):
|
||||
"""
|
||||
Loss used for training.
|
||||
"""
|
||||
return self._loss
|
||||
|
||||
|
||||
@property
|
||||
def current_condition_name(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user