version 0.0.1
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
import torch
|
||||
|
||||
class LabelTensor():
|
||||
class LabelTensor():
|
||||
|
||||
def __init__(self, x, labels):
|
||||
def __init__(self, x, labels):
|
||||
|
||||
|
||||
if len(labels) != x.shape[1]:
|
||||
@@ -21,7 +21,19 @@ class LabelTensor():
|
||||
return self.tensor
|
||||
|
||||
def __str__(self):
|
||||
return self.tensor, self.labels
|
||||
return '{}\n {}\n'.format(self.labels, self.tensor)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.tensor.shape
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tensor.dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.tensor.device
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
|
||||
Reference in New Issue
Block a user