Update label_tensor.py cpu/gpu (#292)
* Update label_tensor.py cpu/gpu * Update test_adaptive_refinment_callbacks.py * Update test_optimizer_callbacks.py
This commit is contained in:
@@ -176,7 +176,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
tmp = super().cuda(*args, **kwargs)
|
tmp = super().cuda(*args, **kwargs)
|
||||||
new = self.__class__.clone(self)
|
new = self.__class__.clone(self)
|
||||||
new.data = tmp.data
|
new.data = tmp.data
|
||||||
return tmp
|
return new
|
||||||
|
|
||||||
def cpu(self, *args, **kwargs):
|
def cpu(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -185,7 +185,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
tmp = super().cpu(*args, **kwargs)
|
tmp = super().cpu(*args, **kwargs)
|
||||||
new = self.__class__.clone(self)
|
new = self.__class__.clone(self)
|
||||||
new.data = tmp.data
|
new.data = tmp.data
|
||||||
return tmp
|
return new
|
||||||
|
|
||||||
def extract(self, label_to_extract):
|
def extract(self, label_to_extract):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ def test_r3refinment_routine():
|
|||||||
# make the trainer
|
# make the trainer
|
||||||
trainer = Trainer(solver=solver,
|
trainer = Trainer(solver=solver,
|
||||||
callbacks=[R3Refinement(sample_every=1)],
|
callbacks=[R3Refinement(sample_every=1)],
|
||||||
|
accelerator='cpu',
|
||||||
max_epochs=5)
|
max_epochs=5)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|||||||
@@ -84,5 +84,6 @@ def test_switch_optimizer_routine():
|
|||||||
new_optimizers_kwargs={'lr': 0.01},
|
new_optimizers_kwargs={'lr': 0.01},
|
||||||
epoch_switch=3)
|
epoch_switch=3)
|
||||||
],
|
],
|
||||||
|
accelerator='cpu',
|
||||||
max_epochs=5)
|
max_epochs=5)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|||||||
Reference in New Issue
Block a user