Vectorial output

This commit is contained in:
Nicola Demo
2022-03-07 10:09:40 +01:00
parent 1812ddb8d9
commit 8a1f07c8ae
6 changed files with 71 additions and 7 deletions

View File

@@ -52,6 +52,8 @@ class FeedForward(torch.nn.Module):
def forward(self, x):
"""
"""
x = x[self.input_variables]
nf = len(self.extra_features)
if nf == 0:
return LabelTensor(self.model(x.tensor), self.output_variables)