lh solved (#55)

This commit is contained in:
Dario Coscia
2022-12-12 17:00:18 +01:00
committed by GitHub
parent dbd78c9cf3
commit 20f4327b27
3 changed files with 48 additions and 8 deletions

View File

@@ -3,6 +3,7 @@ import torch
from .location import Location
from .label_tensor import LabelTensor
from .utils import torch_lhs
class Span(Location):
@@ -41,10 +42,7 @@ class Span(Location):
elif mode == 'grid':
pts = torch.linspace(0, 1, n).reshape(-1, 1)
elif mode == 'lh' or mode == 'latin':
from scipy.stats import qmc
sampler = qmc.LatinHypercube(d=dim)
pts = sampler.random(n)
pts = torch.from_numpy(pts)
pts = torch_lhs(n, dim)
pts *= bounds[:, 1] - bounds[:, 0]
pts += bounds[:, 0]
@@ -83,7 +81,7 @@ class Span(Location):
return result
def _Nd_sampler(n, mode, variables):
""" Sample ll the variables together """
""" Sample all the variables together """
pairs = [(k, v) for k, v in self.range_.items() if k in variables]
keys, values = map(list, zip(*pairs))
bounds = torch.tensor(values)
@@ -107,7 +105,7 @@ class Span(Location):
if mode in ['grid', 'chebyshev']:
return _1d_sampler(n, mode, variables)
elif mode in ['random', 'lhs']:
elif mode in ['random', 'lh', 'latin']:
return _Nd_sampler(n, mode, variables)
else:
raise ValueError(f'mode={mode} is not valid.')