Refactoring code
This commit is contained in:
19
pina/cube.py
19
pina/cube.py
@@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
from .chebyshev import chebyshev_roots
|
||||
|
||||
|
||||
class Cube():
|
||||
def __init__(self, bound):
|
||||
self.bound = np.asarray(bound)
|
||||
@@ -10,11 +11,15 @@ class Cube():
|
||||
if mode == 'random':
|
||||
pts = np.random.uniform(size=(n, self.bound.shape[0]))
|
||||
elif mode == 'chebyshev':
|
||||
pts = np.array([chebyshev_roots(n) *.5 + .5 for _ in range(self.bound.shape[0])])
|
||||
pts = np.array([
|
||||
chebyshev_roots(n) * .5 + .5
|
||||
for _ in range(self.bound.shape[0])])
|
||||
grids = np.meshgrid(*pts)
|
||||
pts = np.hstack([grid.reshape(-1, 1) for grid in grids])
|
||||
elif mode == 'grid':
|
||||
pts = np.array([np.linspace(0, 1, n) for _ in range(self.bound.shape[0])])
|
||||
pts = np.array([
|
||||
np.linspace(0, 1, n)
|
||||
for _ in range(self.bound.shape[0])])
|
||||
grids = np.meshgrid(*pts)
|
||||
pts = np.hstack([grid.reshape(-1, 1) for grid in grids])
|
||||
elif mode == 'lh' or mode == 'latin':
|
||||
@@ -27,3 +32,13 @@ class Cube():
|
||||
pts += self.bound[:, 0]
|
||||
|
||||
return pts
|
||||
|
||||
def meshgrid(self, n):
|
||||
pts = np.array([
|
||||
np.linspace(0, 1, n)
|
||||
for _ in range(self.bound.shape[0])])
|
||||
|
||||
pts *= self.bound[:, 1] - self.bound[:, 0]
|
||||
pts += self.bound[:, 0]
|
||||
|
||||
return np.meshgrid(*pts)
|
||||
|
||||
Reference in New Issue
Block a user