DeepOnet implementation, LabelTensor modification
* Implementing standard DeepOnet (trunk/branch net) * Implementing multiple reduction/ average techniques * Small change LabelTensor __getitem__ for handling list
This commit is contained in:
committed by
Nicola Demo
parent
15ecaacb7c
commit
b029f18c49
@@ -212,8 +212,11 @@ class LabelTensor(torch.Tensor):
|
||||
if selected_lt.ndim == 1:
|
||||
selected_lt = selected_lt.reshape(-1, 1)
|
||||
if hasattr(self, 'labels'):
|
||||
selected_lt.labels = self.labels[index[1]]
|
||||
|
||||
if isinstance(index[1], list):
|
||||
selected_lt.labels = [self.labels[i] for i in index[1]]
|
||||
else:
|
||||
selected_lt.labels = self.labels[index[1]]
|
||||
|
||||
return selected_lt
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
||||
Reference in New Issue
Block a user