From 0e98abf204ea07e444f7b84abe9815762b1ed5c8 Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Tue, 30 Apr 2024 18:52:09 +0200 Subject: [PATCH] Update label_tensor.py cpu/gpu (#292) * Update label_tensor.py cpu/gpu * Update test_adaptive_refinment_callbacks.py * Update test_optimizer_callbacks.py --- pina/label_tensor.py | 4 ++-- tests/test_callbacks/test_adaptive_refinment_callbacks.py | 1 + tests/test_callbacks/test_optimizer_callbacks.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index fe8e1a8..c8a41f7 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -176,7 +176,7 @@ class LabelTensor(torch.Tensor): tmp = super().cuda(*args, **kwargs) new = self.__class__.clone(self) new.data = tmp.data - return tmp + return new def cpu(self, *args, **kwargs): """ @@ -185,7 +185,7 @@ class LabelTensor(torch.Tensor): tmp = super().cpu(*args, **kwargs) new = self.__class__.clone(self) new.data = tmp.data - return tmp + return new def extract(self, label_to_extract): """ diff --git a/tests/test_callbacks/test_adaptive_refinment_callbacks.py b/tests/test_callbacks/test_adaptive_refinment_callbacks.py index fb74367..214257d 100644 --- a/tests/test_callbacks/test_adaptive_refinment_callbacks.py +++ b/tests/test_callbacks/test_adaptive_refinment_callbacks.py @@ -71,6 +71,7 @@ def test_r3refinment_routine(): # make the trainer trainer = Trainer(solver=solver, callbacks=[R3Refinement(sample_every=1)], + accelerator='cpu', max_epochs=5) trainer.train() diff --git a/tests/test_callbacks/test_optimizer_callbacks.py b/tests/test_callbacks/test_optimizer_callbacks.py index 6c167b6..0b0aaba 100644 --- a/tests/test_callbacks/test_optimizer_callbacks.py +++ b/tests/test_callbacks/test_optimizer_callbacks.py @@ -84,5 +84,6 @@ def test_switch_optimizer_routine(): new_optimizers_kwargs={'lr': 0.01}, epoch_switch=3) ], + accelerator='cpu', max_epochs=5) trainer.train()