Update label_tensor.py cpu/gpu (#292)
* Update label_tensor.py cpu/gpu * Update test_adaptive_refinment_callbacks.py * Update test_optimizer_callbacks.py
This commit is contained in:
@@ -176,7 +176,7 @@ class LabelTensor(torch.Tensor):
|
||||
tmp = super().cuda(*args, **kwargs)
|
||||
new = self.__class__.clone(self)
|
||||
new.data = tmp.data
|
||||
return tmp
|
||||
return new
|
||||
|
||||
def cpu(self, *args, **kwargs):
|
||||
"""
|
||||
@@ -185,7 +185,7 @@ class LabelTensor(torch.Tensor):
|
||||
tmp = super().cpu(*args, **kwargs)
|
||||
new = self.__class__.clone(self)
|
||||
new.data = tmp.data
|
||||
return tmp
|
||||
return new
|
||||
|
||||
def extract(self, label_to_extract):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user