CUDA option for labeltensor (#23)

* fix cuda device for labeltensor
This commit is contained in:
Nicola Demo
2022-09-08 17:31:49 +02:00
committed by GitHub
parent 9b2ab7be41
commit 06932196a8
5 changed files with 61 additions and 56 deletions

View File

@@ -129,10 +129,10 @@ class DeepONet(torch.nn.Module):
# output_ = self.reduction(inner_input)
# print(output_.shape)
print(branch_output.shape)
print(trunk_output.shape)
output_ = self.reduction(trunk_output * branch_output)
output_ = LabelTensor(output_, self.output_variables)
# output_ = LabelTensor(output_, self.output_variables)
output_ = output_.as_subclass(LabelTensor)
output_.labels = self.output_variables
# local_size = int(trunk_output.shape[1]/self.output_dimension)
# for i, var in enumerate(self.output_variables):
# start = i*local_size