Update label_tensor.py cpu/gpu (#292)

* Update label_tensor.py cpu/gpu
* Update test_adaptive_refinment_callbacks.py
* Update test_optimizer_callbacks.py
This commit is contained in:
Dario Coscia
2024-04-30 18:52:09 +02:00
committed by GitHub
parent 7e38fb3a82
commit 0e98abf204
3 changed files with 4 additions and 2 deletions

View File

@@ -176,7 +176,7 @@ class LabelTensor(torch.Tensor):
tmp = super().cuda(*args, **kwargs) tmp = super().cuda(*args, **kwargs)
new = self.__class__.clone(self) new = self.__class__.clone(self)
new.data = tmp.data new.data = tmp.data
return tmp return new
def cpu(self, *args, **kwargs): def cpu(self, *args, **kwargs):
""" """
@@ -185,7 +185,7 @@ class LabelTensor(torch.Tensor):
tmp = super().cpu(*args, **kwargs) tmp = super().cpu(*args, **kwargs)
new = self.__class__.clone(self) new = self.__class__.clone(self)
new.data = tmp.data new.data = tmp.data
return tmp return new
def extract(self, label_to_extract): def extract(self, label_to_extract):
""" """

View File

@@ -71,6 +71,7 @@ def test_r3refinment_routine():
# make the trainer # make the trainer
trainer = Trainer(solver=solver, trainer = Trainer(solver=solver,
callbacks=[R3Refinement(sample_every=1)], callbacks=[R3Refinement(sample_every=1)],
accelerator='cpu',
max_epochs=5) max_epochs=5)
trainer.train() trainer.train()

View File

@@ -84,5 +84,6 @@ def test_switch_optimizer_routine():
new_optimizers_kwargs={'lr': 0.01}, new_optimizers_kwargs={'lr': 0.01},
epoch_switch=3) epoch_switch=3)
], ],
accelerator='cpu',
max_epochs=5) max_epochs=5)
trainer.train() trainer.train()