DeepOnet implementation, LabelTensor modification

* Implementing standard DeepOnet (trunk/branch net)
* Implementing multiple reduction/ average techniques
* Small change  LabelTensor __getitem__ for handling list
This commit is contained in:
Dario Coscia
2023-09-06 12:40:21 +02:00
committed by Nicola Demo
parent 15ecaacb7c
commit b029f18c49
3 changed files with 140 additions and 154 deletions

View File

@@ -212,8 +212,11 @@ class LabelTensor(torch.Tensor):
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(-1, 1)
if hasattr(self, 'labels'):
selected_lt.labels = self.labels[index[1]]
if isinstance(index[1], list):
selected_lt.labels = [self.labels[i] for i in index[1]]
else:
selected_lt.labels = self.labels[index[1]]
return selected_lt
def __len__(self) -> int: