lh solved (#55)
This commit is contained in:
10
pina/span.py
10
pina/span.py
@@ -3,6 +3,7 @@ import torch
|
||||
|
||||
from .location import Location
|
||||
from .label_tensor import LabelTensor
|
||||
from .utils import torch_lhs
|
||||
|
||||
|
||||
class Span(Location):
|
||||
@@ -41,10 +42,7 @@ class Span(Location):
|
||||
elif mode == 'grid':
|
||||
pts = torch.linspace(0, 1, n).reshape(-1, 1)
|
||||
elif mode == 'lh' or mode == 'latin':
|
||||
from scipy.stats import qmc
|
||||
sampler = qmc.LatinHypercube(d=dim)
|
||||
pts = sampler.random(n)
|
||||
pts = torch.from_numpy(pts)
|
||||
pts = torch_lhs(n, dim)
|
||||
|
||||
pts *= bounds[:, 1] - bounds[:, 0]
|
||||
pts += bounds[:, 0]
|
||||
@@ -83,7 +81,7 @@ class Span(Location):
|
||||
return result
|
||||
|
||||
def _Nd_sampler(n, mode, variables):
|
||||
""" Sample ll the variables together """
|
||||
""" Sample all the variables together """
|
||||
pairs = [(k, v) for k, v in self.range_.items() if k in variables]
|
||||
keys, values = map(list, zip(*pairs))
|
||||
bounds = torch.tensor(values)
|
||||
@@ -107,7 +105,7 @@ class Span(Location):
|
||||
|
||||
if mode in ['grid', 'chebyshev']:
|
||||
return _1d_sampler(n, mode, variables)
|
||||
elif mode in ['random', 'lhs']:
|
||||
elif mode in ['random', 'lh', 'latin']:
|
||||
return _Nd_sampler(n, mode, variables)
|
||||
else:
|
||||
raise ValueError(f'mode={mode} is not valid.')
|
||||
|
||||
Reference in New Issue
Block a user