edited utils to take list (#115)
* enhanced difference domain * refactored utils * fixed typo * added tests --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
@@ -3,16 +3,17 @@
|
||||
from .location import Location
|
||||
from ..label_tensor import LabelTensor
|
||||
|
||||
|
||||
class Difference(Location):
|
||||
"""
|
||||
"""
|
||||
def __init__(self, first, second):
|
||||
|
||||
def __init__(self, first, second):
|
||||
|
||||
self.first = first
|
||||
self.second = second
|
||||
|
||||
def sample(self, n, mode ='random', variables='all'):
|
||||
def sample(self, n, mode='random', variables='all'):
|
||||
"""
|
||||
"""
|
||||
assert mode is 'random', 'Only random mode is implemented'
|
||||
@@ -24,4 +25,4 @@ class Difference(Location):
|
||||
samples.append(sample.tolist()[0])
|
||||
|
||||
import torch
|
||||
return LabelTensor(torch.tensor(samples), labels=['x', 'y'])
|
||||
return LabelTensor(torch.tensor(samples), labels=['x', 'y'])
|
||||
|
||||
Reference in New Issue
Block a user