supervised working

This commit is contained in:
Nicola Demo
2024-08-08 16:19:52 +02:00
parent 5245a0b68c
commit 9d9c2aa23e
61 changed files with 375 additions and 262 deletions

View File

@@ -229,10 +229,6 @@ from torch import Tensor
# detached._labels = self._labels
# return detached
# def requires_grad_(self, mode=True):
# lt = super().requires_grad_(mode)
# lt.labels = self.labels
# return lt
# def append(self, lt, mode="std"):
# """
@@ -406,11 +402,29 @@ class LabelTensor(torch.Tensor):
return LabelTensor(new_tensor, label_to_extract)
def __str__(self):
s = ''
for key, value in self.labels.items():
s += f"{key}: {value}\n"
s += '\n'
s += super().__str__()
return s
return s
@staticmethod
def stack(tensors):
"""
"""
if len(tensors) == 0:
return []
if len(tensors) == 1:
return tensors[0]
raise NotImplementedError
labels = [tensor.labels for tensor in tensors]
print(labels)
def requires_grad_(self, mode=True):
lt = super().requires_grad_(mode)
lt.labels = self.labels
return lt