🎨 Format Python code with psf/black
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
""" Module for SystemEquation. """
|
||||
|
||||
import torch
|
||||
from .equation import Equation
|
||||
from ..utils import check_consistency
|
||||
@@ -6,7 +7,7 @@ from ..utils import check_consistency
|
||||
|
||||
class SystemEquation(Equation):
|
||||
|
||||
def __init__(self, list_equation, reduction='mean'):
|
||||
def __init__(self, list_equation, reduction="mean"):
|
||||
"""
|
||||
System of Equation class for specifing any system
|
||||
of equations in PINA.
|
||||
@@ -33,15 +34,16 @@ class SystemEquation(Equation):
|
||||
self.equations.append(Equation(equation))
|
||||
|
||||
# possible reduction
|
||||
if reduction == 'mean':
|
||||
if reduction == "mean":
|
||||
self.reduction = torch.mean
|
||||
elif reduction == 'sum':
|
||||
elif reduction == "sum":
|
||||
self.reduction = torch.sum
|
||||
elif (reduction == 'none') or callable(reduction):
|
||||
elif (reduction == "none") or callable(reduction):
|
||||
self.reduction = reduction
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Only mean and sum reductions implemented.')
|
||||
"Only mean and sum reductions implemented."
|
||||
)
|
||||
|
||||
def residual(self, input_, output_, params_=None):
|
||||
"""
|
||||
@@ -64,9 +66,13 @@ class SystemEquation(Equation):
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
residual = torch.hstack(
|
||||
[equation.residual(input_, output_, params_) for equation in self.equations])
|
||||
[
|
||||
equation.residual(input_, output_, params_)
|
||||
for equation in self.equations
|
||||
]
|
||||
)
|
||||
|
||||
if self.reduction == 'none':
|
||||
if self.reduction == "none":
|
||||
return residual
|
||||
|
||||
return self.reduction(residual, dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user