🎨 Format Python code with psf/black
This commit is contained in:
@@ -6,13 +6,15 @@ from ..label_tensor import LabelTensor
|
||||
|
||||
class Network(torch.nn.Module):
|
||||
|
||||
def __init__(self, model, input_variables, output_variables, extra_features=None):
|
||||
def __init__(
|
||||
self, model, input_variables, output_variables, extra_features=None
|
||||
):
|
||||
"""
|
||||
Network class with standard forward method
|
||||
and possibility to pass extra features. This
|
||||
class is used internally in PINA to convert
|
||||
any :class:`torch.nn.Module` s in a PINA module.
|
||||
|
||||
|
||||
:param model: The torch model to convert in a PINA model.
|
||||
:type model: torch.nn.Module
|
||||
:param list(str) input_variables: The input variables of the :class:`AbstractProblem`, whose type depends on the
|
||||
@@ -57,7 +59,9 @@ class Network(torch.nn.Module):
|
||||
:return torch.Tensor: Output of the network.
|
||||
"""
|
||||
# only labeltensors as input
|
||||
assert isinstance(x, LabelTensor), "Expected LabelTensor as input to the model."
|
||||
assert isinstance(
|
||||
x, LabelTensor
|
||||
), "Expected LabelTensor as input to the model."
|
||||
|
||||
# extract torch.Tensor from corresponding label
|
||||
# in case `input_variables = []` all points are used
|
||||
@@ -75,7 +79,7 @@ class Network(torch.nn.Module):
|
||||
output.labels = self._output_variables
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# TODO to remove in next releases (only used in GAROM solver)
|
||||
def forward_map(self, x):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user