Fix bug in handling labels with LabelTensor (#460)

---------

Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
This commit is contained in:
MatteB03
2025-02-25 14:12:51 +01:00
committed by Nicola Demo
parent 42ab1a666b
commit 1bb35f7951
2 changed files with 4 additions and 1 deletions

View File

@@ -448,6 +448,9 @@ class LabelTensor(torch.Tensor):
# Retrieve selected tensor and labels
selected_tensor = super().__getitem__(index)
if not hasattr(self, "_labels"):
return selected_tensor
original_labels = self._labels
updated_labels = copy(original_labels)

View File

@@ -265,7 +265,7 @@ class ContinuousConvBlock(BaseContinuousConv):
"""
# initialize to all zeros
tmp = torch.zeros_like(X)
tmp = torch.zeros_like(X).as_subclass(torch.Tensor)
tmp[..., :-1] = X[..., :-1]
# save on tmp