use LabelTensor, fix minor, docs

This commit is contained in:
Your Name
2022-03-29 18:05:26 +02:00
parent 12f4084d7f
commit 6b001c6c53
19 changed files with 370 additions and 322 deletions

View File

@@ -37,7 +37,6 @@ class Span(Location):
for _ in range(bounds.shape[0])])
grids = np.meshgrid(*pts)
pts = np.hstack([grid.reshape(-1, 1) for grid in grids])
print(pts)
elif mode == 'lh' or mode == 'latin':
from scipy.stats import qmc
sampler = qmc.LatinHypercube(d=bounds.shape[0])
@@ -46,15 +45,17 @@ class Span(Location):
# Scale pts
pts *= bounds[:, 1] - bounds[:, 0]
pts += bounds[:, 0]
pts = pts.astype(np.float32)
pts = torch.from_numpy(pts)
pts_range_ = LabelTensor(pts, list(self.range_.keys()))
fixed = torch.Tensor(list(self.fixed_.values()))
pts_fixed_ = torch.ones(pts_range_.tensor.shape[0], len(self.fixed_)) * fixed
pts_fixed_ = torch.ones(pts.shape[0], len(self.fixed_),
dtype=pts.dtype) * fixed
pts_range_ = LabelTensor(pts, list(self.range_.keys()))
pts_fixed_ = LabelTensor(pts_fixed_, list(self.fixed_.keys()))
if self.fixed_:
return LabelTensor.hstack([pts_range_, pts_fixed_])
return pts_range_.append(pts_fixed_)
else:
return pts_range_