@@ -129,10 +129,10 @@ class DeepONet(torch.nn.Module):
|
||||
|
||||
# output_ = self.reduction(inner_input)
|
||||
# print(output_.shape)
|
||||
print(branch_output.shape)
|
||||
print(trunk_output.shape)
|
||||
output_ = self.reduction(trunk_output * branch_output)
|
||||
output_ = LabelTensor(output_, self.output_variables)
|
||||
# output_ = LabelTensor(output_, self.output_variables)
|
||||
output_ = output_.as_subclass(LabelTensor)
|
||||
output_.labels = self.output_variables
|
||||
# local_size = int(trunk_output.shape[1]/self.output_dimension)
|
||||
# for i, var in enumerate(self.output_variables):
|
||||
# start = i*local_size
|
||||
|
||||
Reference in New Issue
Block a user