diff --git a/pina/chebyshev.py b/pina/chebyshev.py index 89cb1e6..d9f7e4b 100644 --- a/pina/chebyshev.py +++ b/pina/chebyshev.py @@ -1,7 +1,9 @@ -import numpy as np +import torch + + def chebyshev_roots(n): """ Return the roots of *n* Chebyshev polynomials (between [-1, 1]) """ - coefficents = np.zeros(n+1) - coefficents[-1] = 1 - return np.polynomial.chebyshev.chebroots(coefficents) - + pi = torch.acos(torch.zeros(1)).item() * 2 + k = torch.arange(n) + nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0] + return nodes diff --git a/pina/span.py b/pina/span.py index 9f83ce7..046959c 100644 --- a/pina/span.py +++ b/pina/span.py @@ -1,4 +1,3 @@ -import numpy as np from .chebyshev import chebyshev_roots import torch @@ -32,20 +31,20 @@ class Span(Location): """ """ if mode == 'random': - pts = np.random.uniform(size=(n, 1)) + pts = torch.rand(size=(n, 1)) elif mode == 'chebyshev': - pts = np.array([chebyshev_roots(n) * .5 + .5]).reshape(-1, 1) + pts = chebyshev_roots(n).mul(.5).add(.5).reshape(-1, 1) elif mode == 'grid': - pts = np.linspace(0, 1, n).reshape(-1, 1) + pts = torch.linspace(0, 1, n).reshape(-1, 1) elif mode == 'lh' or mode == 'latin': from scipy.stats import qmc sampler = qmc.LatinHypercube(d=1) pts = sampler.random(n) + pts = torch.from_numpy(pts) pts *= bounds[1] - bounds[0] pts += bounds[0] - pts = pts.astype(np.float32) return pts def sample(self, n, mode='random', variables='all'): @@ -56,10 +55,9 @@ class Span(Location): result = None for variable in variables: if variable in self.range_.keys(): - bound = np.asarray(self.range_[variable]) + bound = torch.tensor(self.range_[variable]) pts_variable = self._sample_range(n, mode, bound) - pts_variable = LabelTensor( - torch.from_numpy(pts_variable), [variable]) + pts_variable = LabelTensor(pts_variable, [variable]) elif variable in self.fixed_.keys(): value = self.fixed_[variable]