supervised working
This commit is contained in:
@@ -229,10 +229,6 @@ from torch import Tensor
|
||||
# detached._labels = self._labels
|
||||
# return detached
|
||||
|
||||
# def requires_grad_(self, mode=True):
|
||||
# lt = super().requires_grad_(mode)
|
||||
# lt.labels = self.labels
|
||||
# return lt
|
||||
|
||||
# def append(self, lt, mode="std"):
|
||||
# """
|
||||
@@ -406,11 +402,29 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
return LabelTensor(new_tensor, label_to_extract)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
s = ''
|
||||
for key, value in self.labels.items():
|
||||
s += f"{key}: {value}\n"
|
||||
s += '\n'
|
||||
s += super().__str__()
|
||||
return s
|
||||
return s
|
||||
|
||||
@staticmethod
|
||||
def stack(tensors):
|
||||
"""
|
||||
"""
|
||||
if len(tensors) == 0:
|
||||
return []
|
||||
|
||||
if len(tensors) == 1:
|
||||
return tensors[0]
|
||||
|
||||
raise NotImplementedError
|
||||
labels = [tensor.labels for tensor in tensors]
|
||||
print(labels)
|
||||
|
||||
def requires_grad_(self, mode=True):
|
||||
lt = super().requires_grad_(mode)
|
||||
lt.labels = self.labels
|
||||
return lt
|
||||
Reference in New Issue
Block a user