From 1bb35f79510d8dacc753d411daedac93dcbfe0b3 Mon Sep 17 00:00:00 2001 From: MatteB03 Date: Tue, 25 Feb 2025 14:12:51 +0100 Subject: [PATCH] Fix bug in handling labels with LabelTensor (#460) --------- Co-authored-by: Filippo Olivo --- pina/label_tensor.py | 3 +++ pina/model/block/convolution_2d.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 4c30f4b..6448044 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -448,6 +448,9 @@ class LabelTensor(torch.Tensor): # Retrieve selected tensor and labels selected_tensor = super().__getitem__(index) + if not hasattr(self, "_labels"): + return selected_tensor + original_labels = self._labels updated_labels = copy(original_labels) diff --git a/pina/model/block/convolution_2d.py b/pina/model/block/convolution_2d.py index 665ddaf..4c08533 100644 --- a/pina/model/block/convolution_2d.py +++ b/pina/model/block/convolution_2d.py @@ -265,7 +265,7 @@ class ContinuousConvBlock(BaseContinuousConv): """ # initialize to all zeros - tmp = torch.zeros_like(X) + tmp = torch.zeros_like(X).as_subclass(torch.Tensor) tmp[..., :-1] = X[..., :-1] # save on tmp