Span (#24)
* Changing chebyshev implementation to remove numpy dependencies * Update span.py
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def chebyshev_roots(n):
|
||||
""" Return the roots of *n* Chebyshev polynomials (between [-1, 1]) """
|
||||
coefficents = np.zeros(n+1)
|
||||
coefficents[-1] = 1
|
||||
return np.polynomial.chebyshev.chebroots(coefficents)
|
||||
|
||||
pi = torch.acos(torch.zeros(1)).item() * 2
|
||||
k = torch.arange(n)
|
||||
nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0]
|
||||
return nodes
|
||||
|
||||
14
pina/span.py
14
pina/span.py
@@ -1,4 +1,3 @@
|
||||
import numpy as np
|
||||
from .chebyshev import chebyshev_roots
|
||||
import torch
|
||||
|
||||
@@ -32,20 +31,20 @@ class Span(Location):
|
||||
"""
|
||||
"""
|
||||
if mode == 'random':
|
||||
pts = np.random.uniform(size=(n, 1))
|
||||
pts = torch.rand(size=(n, 1))
|
||||
elif mode == 'chebyshev':
|
||||
pts = np.array([chebyshev_roots(n) * .5 + .5]).reshape(-1, 1)
|
||||
pts = chebyshev_roots(n).mul(.5).add(.5).reshape(-1, 1)
|
||||
elif mode == 'grid':
|
||||
pts = np.linspace(0, 1, n).reshape(-1, 1)
|
||||
pts = torch.linspace(0, 1, n).reshape(-1, 1)
|
||||
elif mode == 'lh' or mode == 'latin':
|
||||
from scipy.stats import qmc
|
||||
sampler = qmc.LatinHypercube(d=1)
|
||||
pts = sampler.random(n)
|
||||
pts = torch.from_numpy(pts)
|
||||
|
||||
pts *= bounds[1] - bounds[0]
|
||||
pts += bounds[0]
|
||||
|
||||
pts = pts.astype(np.float32)
|
||||
return pts
|
||||
|
||||
def sample(self, n, mode='random', variables='all'):
|
||||
@@ -56,10 +55,9 @@ class Span(Location):
|
||||
result = None
|
||||
for variable in variables:
|
||||
if variable in self.range_.keys():
|
||||
bound = np.asarray(self.range_[variable])
|
||||
bound = torch.tensor(self.range_[variable])
|
||||
pts_variable = self._sample_range(n, mode, bound)
|
||||
pts_variable = LabelTensor(
|
||||
torch.from_numpy(pts_variable), [variable])
|
||||
pts_variable = LabelTensor(pts_variable, [variable])
|
||||
|
||||
elif variable in self.fixed_.keys():
|
||||
value = self.fixed_[variable]
|
||||
|
||||
Reference in New Issue
Block a user