🎨 Format Python code with psf/black
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
__all__ = [
|
||||
'SystemEquation',
|
||||
'Equation',
|
||||
'FixedValue',
|
||||
'FixedGradient',
|
||||
'FixedFlux',
|
||||
'Laplace',
|
||||
"SystemEquation",
|
||||
"Equation",
|
||||
"FixedValue",
|
||||
"FixedGradient",
|
||||
"FixedFlux",
|
||||
"Laplace",
|
||||
]
|
||||
|
||||
from .equation import Equation
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
""" Module for Equation. """
|
||||
|
||||
from .equation_interface import EquationInterface
|
||||
|
||||
|
||||
@@ -15,12 +16,14 @@ class Equation(EquationInterface):
|
||||
:type equation: Callable
|
||||
"""
|
||||
if not callable(equation):
|
||||
raise ValueError('equation must be a callable function.'
|
||||
'Expected a callable function, got '
|
||||
f'{equation}')
|
||||
raise ValueError(
|
||||
"equation must be a callable function."
|
||||
"Expected a callable function, got "
|
||||
f"{equation}"
|
||||
)
|
||||
self.__equation = equation
|
||||
|
||||
def residual(self, input_, output_, params_ = None):
|
||||
def residual(self, input_, output_, params_=None):
|
||||
"""
|
||||
Residual computation of the equation.
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
""" Module """
|
||||
|
||||
from .equation import Equation
|
||||
from ..operators import grad, div, laplacian
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
""" Module for EquationInterface class """
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
""" Module for SystemEquation. """
|
||||
|
||||
import torch
|
||||
from .equation import Equation
|
||||
from ..utils import check_consistency
|
||||
@@ -6,7 +7,7 @@ from ..utils import check_consistency
|
||||
|
||||
class SystemEquation(Equation):
|
||||
|
||||
def __init__(self, list_equation, reduction='mean'):
|
||||
def __init__(self, list_equation, reduction="mean"):
|
||||
"""
|
||||
System of Equation class for specifing any system
|
||||
of equations in PINA.
|
||||
@@ -33,15 +34,16 @@ class SystemEquation(Equation):
|
||||
self.equations.append(Equation(equation))
|
||||
|
||||
# possible reduction
|
||||
if reduction == 'mean':
|
||||
if reduction == "mean":
|
||||
self.reduction = torch.mean
|
||||
elif reduction == 'sum':
|
||||
elif reduction == "sum":
|
||||
self.reduction = torch.sum
|
||||
elif (reduction == 'none') or callable(reduction):
|
||||
elif (reduction == "none") or callable(reduction):
|
||||
self.reduction = reduction
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Only mean and sum reductions implemented.')
|
||||
"Only mean and sum reductions implemented."
|
||||
)
|
||||
|
||||
def residual(self, input_, output_, params_=None):
|
||||
"""
|
||||
@@ -64,9 +66,13 @@ class SystemEquation(Equation):
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
residual = torch.hstack(
|
||||
[equation.residual(input_, output_, params_) for equation in self.equations])
|
||||
[
|
||||
equation.residual(input_, output_, params_)
|
||||
for equation in self.equations
|
||||
]
|
||||
)
|
||||
|
||||
if self.reduction == 'none':
|
||||
if self.reduction == "none":
|
||||
return residual
|
||||
|
||||
return self.reduction(residual, dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user