Changes to Tensor labels handling
This commit is contained in:
committed by
Nicola Demo
parent
f67467e5bd
commit
8b797d589a
@@ -479,9 +479,10 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
# Retrieve selected tensor and labels
|
||||
selected_tensor = super().__getitem__(index)
|
||||
if not hasattr(self, "_labels"):
|
||||
if hasattr(self, "_labels"):
|
||||
original_labels=self._labels
|
||||
else:
|
||||
return selected_tensor
|
||||
|
||||
original_labels = self._labels
|
||||
updated_labels = copy(original_labels)
|
||||
|
||||
|
||||
@@ -500,6 +500,7 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
# initialize grid
|
||||
X = self._grid_transpose.clone().detach()
|
||||
|
||||
conv_transposed = self._grid_transpose.clone().detach()
|
||||
|
||||
# list to iterate for calculating nn output
|
||||
|
||||
Reference in New Issue
Block a user