Vectorial output
This commit is contained in:
@@ -52,6 +52,8 @@ class FeedForward(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
"""
|
||||
"""
|
||||
|
||||
x = x[self.input_variables]
|
||||
nf = len(self.extra_features)
|
||||
if nf == 0:
|
||||
return LabelTensor(self.model(x.tensor), self.output_variables)
|
||||
|
||||
Reference in New Issue
Block a user