Fix bug in handling labels with LabelTensor (#460)
--------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com>
This commit is contained in:
@@ -448,6 +448,9 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
# Retrieve selected tensor and labels
|
# Retrieve selected tensor and labels
|
||||||
selected_tensor = super().__getitem__(index)
|
selected_tensor = super().__getitem__(index)
|
||||||
|
if not hasattr(self, "_labels"):
|
||||||
|
return selected_tensor
|
||||||
|
|
||||||
original_labels = self._labels
|
original_labels = self._labels
|
||||||
updated_labels = copy(original_labels)
|
updated_labels = copy(original_labels)
|
||||||
|
|
||||||
|
|||||||
@@ -265,7 +265,7 @@ class ContinuousConvBlock(BaseContinuousConv):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
# initialize to all zeros
|
# initialize to all zeros
|
||||||
tmp = torch.zeros_like(X)
|
tmp = torch.zeros_like(X).as_subclass(torch.Tensor)
|
||||||
tmp[..., :-1] = X[..., :-1]
|
tmp[..., :-1] = X[..., :-1]
|
||||||
|
|
||||||
# save on tmp
|
# save on tmp
|
||||||
|
|||||||
Reference in New Issue
Block a user