Fix bugs in 0.2 (#344)

* Fix some bugs
This commit is contained in:
FilippoOlivo
2024-09-12 18:12:59 +02:00
committed by Nicola Demo
parent f0d68b34c7
commit 30f865d912
11 changed files with 112 additions and 55 deletions

View File

@@ -5,7 +5,6 @@ import torch
from torch import Tensor
# class LabelTensor(torch.Tensor):
# """Torch tensor with a label for any column."""
@@ -307,13 +306,13 @@ from torch import Tensor
# s = "no labels\n"
# s += super().__str__()
# return s
def issubset(a, b):
"""
Check if a is a subset of b.
"""
return set(a).issubset(set(b))
class LabelTensor(torch.Tensor):
"""Torch tensor with a label for any column."""
@@ -403,6 +402,10 @@ class LabelTensor(torch.Tensor):
return LabelTensor(new_tensor, label_to_extract)
def __str__(self):
"""
returns a string with the representation of the class
"""
s = ''
for key, value in self.labels.items():
s += f"{key}: {value}\n"
@@ -431,4 +434,32 @@ class LabelTensor(torch.Tensor):
@property
def dtype(self):
return super().dtype
return super().dtype
def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`.
"""
tmp = super().to(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return new
def clone(self, *args, **kwargs):
"""
Clone the LabelTensor. For more details, see
:meth:`torch.Tensor.clone`.
:return: A copy of the tensor.
:rtype: LabelTensor
"""
# # used before merging
# try:
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
# except:
# out = super().clone(*args, **kwargs)
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
return out