🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -6,13 +6,15 @@ from ..label_tensor import LabelTensor
class Network(torch.nn.Module):
def __init__(self, model, input_variables, output_variables, extra_features=None):
def __init__(
self, model, input_variables, output_variables, extra_features=None
):
"""
Network class with standard forward method
and possibility to pass extra features. This
class is used internally in PINA to convert
any :class:`torch.nn.Module` s in a PINA module.
:param model: The torch model to convert in a PINA model.
:type model: torch.nn.Module
:param list(str) input_variables: The input variables of the :class:`AbstractProblem`, whose type depends on the
@@ -57,7 +59,9 @@ class Network(torch.nn.Module):
:return torch.Tensor: Output of the network.
"""
# only labeltensors as input
assert isinstance(x, LabelTensor), "Expected LabelTensor as input to the model."
assert isinstance(
x, LabelTensor
), "Expected LabelTensor as input to the model."
# extract torch.Tensor from corresponding label
# in case `input_variables = []` all points are used
@@ -75,7 +79,7 @@ class Network(torch.nn.Module):
output.labels = self._output_variables
return output
# TODO to remove in next releases (only used in GAROM solver)
def forward_map(self, x):
"""