equation class, fix minor bugs, diff domain (#89)

* equation class
* difference domain
* dummy dataloader
* writer class
* refactoring and minor fix
This commit is contained in:
Nicola Demo
2023-05-15 16:06:01 +02:00
parent be11110bb2
commit 0e3625de80
25 changed files with 691 additions and 246 deletions

View File

@@ -1,5 +1,7 @@
""" Module for LabelTensor """
from typing import Any
import torch
from torch import Tensor
class LabelTensor(torch.Tensor):
@@ -79,7 +81,7 @@ class LabelTensor(torch.Tensor):
@labels.setter
def labels(self, labels):
if len(labels) != self.shape[1]: # small check
if len(labels) != self.shape[self.ndim - 1]: # small check
raise ValueError(
'the tensor has not the same number of columns of '
'the passed labels.')
@@ -140,7 +142,7 @@ class LabelTensor(torch.Tensor):
except ValueError:
raise ValueError(f'`{f}` not in the labels list')
new_data = self[:, indeces].float()
new_data = super(Tensor, self.T).__getitem__(indeces).float().T
new_labels = [self.labels[idx] for idx in indeces]
extracted_tensor = new_data.as_subclass(LabelTensor)
@@ -183,6 +185,19 @@ class LabelTensor(torch.Tensor):
new_tensor.labels = new_labels
return new_tensor
def __getitem__(self, index):
"""
Return a copy of the selected tensor.
"""
selected_lt = super(Tensor, self).__getitem__(index)
if hasattr(self, 'labels'):
selected_lt.labels = self.labels
return selected_lt
def __len__(self) -> int:
return super().__len__()
def __str__(self):
if hasattr(self, 'labels'):
s = f'labels({str(self.labels)})\n'