Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -602,3 +602,20 @@ class LabelTensor(torch.Tensor):
|
||||
}
|
||||
|
||||
return LabelTensor(data, labels)
|
||||
|
||||
def reshape(self, *shape):
|
||||
"""
|
||||
Override the reshape method to update the labels of the tensor.
|
||||
|
||||
:param shape: The new shape of the tensor.
|
||||
:type shape: tuple
|
||||
:return: A tensor-like object with updated labels.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# As for now the reshape method is used only in the context of the
|
||||
# dataset, the labels are not
|
||||
tensor = super().reshape(*shape)
|
||||
if not hasattr(self, "_labels") or shape != (-1, *self.shape[2:]):
|
||||
return tensor
|
||||
tensor.labels = self.labels
|
||||
return tensor
|
||||
|
||||
Reference in New Issue
Block a user