73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
from .chebyshev import chebyshev_roots
|
|
import torch
|
|
|
|
from .location import Location
|
|
from .label_tensor import LabelTensor
|
|
|
|
|
|
class Span(Location):
|
|
def __init__(self, span_dict):
|
|
|
|
self.fixed_ = {}
|
|
self.range_ = {}
|
|
|
|
for k, v in span_dict.items():
|
|
if isinstance(v, (int, float)):
|
|
self.fixed_[k] = v
|
|
elif isinstance(v, (list, tuple)) and len(v) == 2:
|
|
self.range_[k] = v
|
|
else:
|
|
raise TypeError
|
|
|
|
@property
|
|
def variables(self):
|
|
return list(self.fixed_.keys()) + list(self.range_.keys())
|
|
|
|
def update(self, new_span):
|
|
self.fixed_.update(new_span.fixed_)
|
|
self.range_.update(new_span.range_)
|
|
|
|
def _sample_range(self, n, mode, bounds):
|
|
"""
|
|
"""
|
|
if mode == 'random':
|
|
pts = torch.rand(size=(n, 1))
|
|
elif mode == 'chebyshev':
|
|
pts = chebyshev_roots(n).mul(.5).add(.5).reshape(-1, 1)
|
|
elif mode == 'grid':
|
|
pts = torch.linspace(0, 1, n).reshape(-1, 1)
|
|
elif mode == 'lh' or mode == 'latin':
|
|
from scipy.stats import qmc
|
|
sampler = qmc.LatinHypercube(d=1)
|
|
pts = sampler.random(n)
|
|
pts = torch.from_numpy(pts)
|
|
|
|
pts *= bounds[1] - bounds[0]
|
|
pts += bounds[0]
|
|
|
|
return pts
|
|
|
|
def sample(self, n, mode='random', variables='all'):
|
|
|
|
if variables == 'all':
|
|
variables = list(self.range_.keys()) + list(self.fixed_.keys())
|
|
|
|
result = None
|
|
for variable in variables:
|
|
if variable in self.range_.keys():
|
|
bound = torch.tensor(self.range_[variable])
|
|
pts_variable = self._sample_range(n, mode, bound)
|
|
pts_variable = LabelTensor(pts_variable, [variable])
|
|
|
|
elif variable in self.fixed_.keys():
|
|
value = self.fixed_[variable]
|
|
pts_variable = LabelTensor(torch.ones(n, 1)*value, [variable])
|
|
|
|
if result is None:
|
|
result = pts_variable
|
|
else:
|
|
intersect = 'std' if mode == 'random' else 'cross'
|
|
result = result.append(pts_variable, intersect)
|
|
|
|
return result
|