92 lines
3.0 KiB
Python
92 lines
3.0 KiB
Python
import numpy as np
|
|
from .chebyshev import chebyshev_roots
|
|
import torch
|
|
|
|
from .location import Location
|
|
from .label_tensor import LabelTensor
|
|
|
|
|
|
class Span(Location):
|
|
def __init__(self, span_dict):
|
|
|
|
self.fixed_ = {}
|
|
self.range_ = {}
|
|
|
|
for k, v in span_dict.items():
|
|
if isinstance(v, (int, float)):
|
|
self.fixed_[k] = v
|
|
elif isinstance(v, (list, tuple)) and len(v) == 2:
|
|
self.range_[k] = v
|
|
else:
|
|
raise TypeError
|
|
|
|
def sample(self, n, mode='random', variables='all'):
|
|
|
|
if variables == 'all':
|
|
spatial_range_ = list(self.range_.keys())
|
|
spatial_fixed_ = list(self.fixed_.keys())
|
|
bounds = np.array(list(self.range_.values()))
|
|
fixed = np.array(list(self.fixed_.values()))
|
|
else:
|
|
bounds = []
|
|
spatial_range_ = []
|
|
spatial_fixed_ = []
|
|
fixed = []
|
|
for variable in variables:
|
|
if variable in self.range_.keys():
|
|
spatial_range_.append(variable)
|
|
bounds.append(list(self.range_[variable]))
|
|
elif variable in self.fixed_.keys():
|
|
spatial_fixed_.append(variable)
|
|
fixed.append(int(self.fixed_[variable]))
|
|
fixed = torch.Tensor(fixed)
|
|
bounds = np.array(bounds)
|
|
|
|
if mode == 'random':
|
|
pts = np.random.uniform(size=(n, bounds.shape[0]))
|
|
elif mode == 'chebyshev':
|
|
pts = np.array([
|
|
chebyshev_roots(n) * .5 + .5
|
|
for _ in range(bounds.shape[0])])
|
|
grids = np.meshgrid(*pts)
|
|
pts = np.hstack([grid.reshape(-1, 1) for grid in grids])
|
|
elif mode == 'grid':
|
|
pts = np.array([
|
|
np.linspace(0, 1, n)
|
|
for _ in range(bounds.shape[0])])
|
|
grids = np.meshgrid(*pts)
|
|
pts = np.hstack([grid.reshape(-1, 1) for grid in grids])
|
|
elif mode == 'lh' or mode == 'latin':
|
|
from scipy.stats import qmc
|
|
sampler = qmc.LatinHypercube(d=bounds.shape[0])
|
|
pts = sampler.random(n)
|
|
|
|
# Scale pts
|
|
pts *= bounds[:, 1] - bounds[:, 0]
|
|
pts += bounds[:, 0]
|
|
|
|
pts = pts.astype(np.float32)
|
|
pts = torch.from_numpy(pts)
|
|
|
|
pts_range_ = LabelTensor(pts, spatial_range_)
|
|
|
|
if not len(spatial_fixed_)==0:
|
|
pts_fixed_ = torch.ones(pts.shape[0], len(spatial_fixed_),
|
|
dtype=pts.dtype) * fixed
|
|
pts_fixed_ = pts_fixed_.float()
|
|
pts_fixed_ = LabelTensor(pts_fixed_, spatial_fixed_)
|
|
pts_range_ = pts_range_.append(pts_fixed_)
|
|
|
|
return pts_range_
|
|
|
|
|
|
def meshgrid(self, n):
|
|
pts = np.array([
|
|
np.linspace(0, 1, n)
|
|
for _ in range(self.bound.shape[0])])
|
|
|
|
pts *= self.bound[:, 1] - self.bound[:, 0]
|
|
pts += self.bound[:, 0]
|
|
|
|
return np.meshgrid(*pts)
|