committed by
Nicola Demo
parent
f0d68b34c7
commit
30f865d912
@@ -5,7 +5,6 @@ import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
|
||||
# class LabelTensor(torch.Tensor):
|
||||
# """Torch tensor with a label for any column."""
|
||||
|
||||
@@ -307,13 +306,13 @@ from torch import Tensor
|
||||
# s = "no labels\n"
|
||||
# s += super().__str__()
|
||||
# return s
|
||||
|
||||
def issubset(a, b):
|
||||
"""
|
||||
Check if a is a subset of b.
|
||||
"""
|
||||
return set(a).issubset(set(b))
|
||||
|
||||
|
||||
class LabelTensor(torch.Tensor):
|
||||
"""Torch tensor with a label for any column."""
|
||||
|
||||
@@ -403,6 +402,10 @@ class LabelTensor(torch.Tensor):
|
||||
return LabelTensor(new_tensor, label_to_extract)
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
returns a string with the representation of the class
|
||||
"""
|
||||
|
||||
s = ''
|
||||
for key, value in self.labels.items():
|
||||
s += f"{key}: {value}\n"
|
||||
@@ -431,4 +434,32 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return super().dtype
|
||||
return super().dtype
|
||||
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""
|
||||
Performs Tensor dtype and/or device conversion. For more details, see
|
||||
:meth:`torch.Tensor.to`.
|
||||
"""
|
||||
tmp = super().to(*args, **kwargs)
|
||||
new = self.__class__.clone(self)
|
||||
new.data = tmp.data
|
||||
return new
|
||||
|
||||
|
||||
def clone(self, *args, **kwargs):
|
||||
"""
|
||||
Clone the LabelTensor. For more details, see
|
||||
:meth:`torch.Tensor.clone`.
|
||||
|
||||
:return: A copy of the tensor.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# # used before merging
|
||||
# try:
|
||||
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||
# except:
|
||||
# out = super().clone(*args, **kwargs)
|
||||
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||
return out
|
||||
Reference in New Issue
Block a user