Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -602,3 +602,20 @@ class LabelTensor(torch.Tensor):
}
return LabelTensor(data, labels)
def reshape(self, *shape):
"""
Override the reshape method to update the labels of the tensor.
:param shape: The new shape of the tensor.
:type shape: tuple
:return: A tensor-like object with updated labels.
:rtype: LabelTensor
"""
# As for now the reshape method is used only in the context of the
# dataset, the labels are not
tensor = super().reshape(*shape)
if not hasattr(self, "_labels") or shape != (-1, *self.shape[2:]):
return tensor
tensor.labels = self.labels
return tensor