Changes to Tensor labels handling

This commit is contained in:
Matteo Bertocchi
2025-02-25 12:44:18 +01:00
committed by Nicola Demo
parent f67467e5bd
commit 8b797d589a
2 changed files with 4 additions and 2 deletions

View File

@@ -479,9 +479,10 @@ class LabelTensor(torch.Tensor):
# Retrieve selected tensor and labels
selected_tensor = super().__getitem__(index)
if not hasattr(self, "_labels"):
if hasattr(self, "_labels"):
original_labels=self._labels
else:
return selected_tensor
original_labels = self._labels
updated_labels = copy(original_labels)

View File

@@ -500,6 +500,7 @@ class ContinuousConvBlock(BaseContinuousConv):
# initialize grid
X = self._grid_transpose.clone().detach()
conv_transposed = self._grid_transpose.clone().detach()
# list to iterate for calculating nn output