equation class, fix minor bugs, diff domain (#89)
* equation class * difference domain * dummy dataloader * writer class * refactoring and minor fix
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
""" Module for LabelTensor """
|
||||
from typing import Any
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class LabelTensor(torch.Tensor):
|
||||
@@ -79,7 +81,7 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
@labels.setter
|
||||
def labels(self, labels):
|
||||
if len(labels) != self.shape[1]: # small check
|
||||
if len(labels) != self.shape[self.ndim - 1]: # small check
|
||||
raise ValueError(
|
||||
'the tensor has not the same number of columns of '
|
||||
'the passed labels.')
|
||||
@@ -140,7 +142,7 @@ class LabelTensor(torch.Tensor):
|
||||
except ValueError:
|
||||
raise ValueError(f'`{f}` not in the labels list')
|
||||
|
||||
new_data = self[:, indeces].float()
|
||||
new_data = super(Tensor, self.T).__getitem__(indeces).float().T
|
||||
new_labels = [self.labels[idx] for idx in indeces]
|
||||
|
||||
extracted_tensor = new_data.as_subclass(LabelTensor)
|
||||
@@ -183,6 +185,19 @@ class LabelTensor(torch.Tensor):
|
||||
new_tensor.labels = new_labels
|
||||
return new_tensor
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Return a copy of the selected tensor.
|
||||
"""
|
||||
selected_lt = super(Tensor, self).__getitem__(index)
|
||||
if hasattr(self, 'labels'):
|
||||
selected_lt.labels = self.labels
|
||||
|
||||
return selected_lt
|
||||
|
||||
def __len__(self) -> int:
|
||||
return super().__len__()
|
||||
|
||||
def __str__(self):
|
||||
if hasattr(self, 'labels'):
|
||||
s = f'labels({str(self.labels)})\n'
|
||||
|
||||
Reference in New Issue
Block a user