Span (#24)
* Changing chebyshev implementation to remove numpy dependencies * Update span.py
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
import numpy as np
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def chebyshev_roots(n):
|
def chebyshev_roots(n):
|
||||||
""" Return the roots of *n* Chebyshev polynomials (between [-1, 1]) """
|
""" Return the roots of *n* Chebyshev polynomials (between [-1, 1]) """
|
||||||
coefficents = np.zeros(n+1)
|
pi = torch.acos(torch.zeros(1)).item() * 2
|
||||||
coefficents[-1] = 1
|
k = torch.arange(n)
|
||||||
return np.polynomial.chebyshev.chebroots(coefficents)
|
nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0]
|
||||||
|
return nodes
|
||||||
|
|||||||
14
pina/span.py
14
pina/span.py
@@ -1,4 +1,3 @@
|
|||||||
import numpy as np
|
|
||||||
from .chebyshev import chebyshev_roots
|
from .chebyshev import chebyshev_roots
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -32,20 +31,20 @@ class Span(Location):
|
|||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
if mode == 'random':
|
if mode == 'random':
|
||||||
pts = np.random.uniform(size=(n, 1))
|
pts = torch.rand(size=(n, 1))
|
||||||
elif mode == 'chebyshev':
|
elif mode == 'chebyshev':
|
||||||
pts = np.array([chebyshev_roots(n) * .5 + .5]).reshape(-1, 1)
|
pts = chebyshev_roots(n).mul(.5).add(.5).reshape(-1, 1)
|
||||||
elif mode == 'grid':
|
elif mode == 'grid':
|
||||||
pts = np.linspace(0, 1, n).reshape(-1, 1)
|
pts = torch.linspace(0, 1, n).reshape(-1, 1)
|
||||||
elif mode == 'lh' or mode == 'latin':
|
elif mode == 'lh' or mode == 'latin':
|
||||||
from scipy.stats import qmc
|
from scipy.stats import qmc
|
||||||
sampler = qmc.LatinHypercube(d=1)
|
sampler = qmc.LatinHypercube(d=1)
|
||||||
pts = sampler.random(n)
|
pts = sampler.random(n)
|
||||||
|
pts = torch.from_numpy(pts)
|
||||||
|
|
||||||
pts *= bounds[1] - bounds[0]
|
pts *= bounds[1] - bounds[0]
|
||||||
pts += bounds[0]
|
pts += bounds[0]
|
||||||
|
|
||||||
pts = pts.astype(np.float32)
|
|
||||||
return pts
|
return pts
|
||||||
|
|
||||||
def sample(self, n, mode='random', variables='all'):
|
def sample(self, n, mode='random', variables='all'):
|
||||||
@@ -56,10 +55,9 @@ class Span(Location):
|
|||||||
result = None
|
result = None
|
||||||
for variable in variables:
|
for variable in variables:
|
||||||
if variable in self.range_.keys():
|
if variable in self.range_.keys():
|
||||||
bound = np.asarray(self.range_[variable])
|
bound = torch.tensor(self.range_[variable])
|
||||||
pts_variable = self._sample_range(n, mode, bound)
|
pts_variable = self._sample_range(n, mode, bound)
|
||||||
pts_variable = LabelTensor(
|
pts_variable = LabelTensor(pts_variable, [variable])
|
||||||
torch.from_numpy(pts_variable), [variable])
|
|
||||||
|
|
||||||
elif variable in self.fixed_.keys():
|
elif variable in self.fixed_.keys():
|
||||||
value = self.fixed_[variable]
|
value = self.fixed_[variable]
|
||||||
|
|||||||
Reference in New Issue
Block a user