fix bug network

This commit is contained in:
Dario Coscia
2023-11-13 12:25:40 +01:00
committed by Nicola Demo
parent ee39b39805
commit a9f14ac323
6 changed files with 127 additions and 80 deletions

View File

@@ -56,6 +56,9 @@ class Network(torch.nn.Module):
:param torch.Tensor x: Input of the network.
:return torch.Tensor: Output of the network.
"""
# only labeltensors as input
assert isinstance(x, LabelTensor), "Expected LabelTensor as input to the model."
# extract torch.Tensor from corresponding label
# in case `input_variables = []` all points are used
if self._input_variables:
@@ -65,22 +68,20 @@ class Network(torch.nn.Module):
for feature in self._extra_features:
x = x.append(feature(x))
# convert LabelTensor to torch.Tensor
x = x.as_subclass(torch.Tensor)
# perform forward pass (using torch.Tensor) + converting to LabelTensor
# perform forward pass + converting to LabelTensor
output = self._model(x).as_subclass(LabelTensor)
# set the labels for LabelTensor
output.labels = self._output_variables
return output
# TODO to remove in next releases (only used in GAROM solver)
def forward_map(self, x):
"""
Forward method for Network class when the input is
a tuple. This class implements the standard forward method,
and it adds the possibility to pass extra features.
a tuple. This class is simply a forward with the input casted as a
tuple or list :class`torch.Tensor`.
All the PINA models ``forward`` s are overriden
by this class, to enable :class:`pina.label_tensor.LabelTensor` labels
extraction.