Fix bug in handling labels with LabelTensor (#460)
--------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
This commit is contained in:
@@ -448,6 +448,9 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
# Retrieve selected tensor and labels
|
||||
selected_tensor = super().__getitem__(index)
|
||||
if not hasattr(self, "_labels"):
|
||||
return selected_tensor
|
||||
|
||||
original_labels = self._labels
|
||||
updated_labels = copy(original_labels)
|
||||
|
||||
|
||||
@@ -265,7 +265,7 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
"""
|
||||
# initialize to all zeros
|
||||
tmp = torch.zeros_like(X)
|
||||
tmp = torch.zeros_like(X).as_subclass(torch.Tensor)
|
||||
tmp[..., :-1] = X[..., :-1]
|
||||
|
||||
# save on tmp
|
||||
|
||||
Reference in New Issue
Block a user