fix Supervised/PINN solvers forward + fix tut5

This commit is contained in:
Dario Coscia
2023-11-09 11:24:00 +01:00
committed by Nicola Demo
parent 4977d55507
commit 4844640727
6 changed files with 123 additions and 79 deletions

View File

@@ -83,11 +83,11 @@ class PINN(SolverInterface):
:return: PINN solution.
:rtype: torch.Tensor
"""
# extract labels
x = x.extract(self.problem.input_variables)
# perform forward pass
# extract torch.Tensor from corresponding label
x = x.extract(self.problem.input_variables).as_subclass(torch.Tensor)
# perform forward pass (using torch.Tensor) + converting to LabelTensor
output = self.neural_net(x).as_subclass(LabelTensor)
# set the labels
# set the labels for LabelTensor
output.labels = self.problem.output_variables
return output