Fix bug in span_pts (#37)
This commit is contained in:
86
pina/span.py
86
pina/span.py
@@ -19,6 +19,8 @@ class Span(Location):
|
|||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
|
print(span_dict, self.fixed_, self.range_, 'YYYYYYYYYY')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def variables(self):
|
def variables(self):
|
||||||
return list(self.fixed_.keys()) + list(self.range_.keys())
|
return list(self.fixed_.keys()) + list(self.range_.keys())
|
||||||
@@ -30,43 +32,85 @@ class Span(Location):
|
|||||||
def _sample_range(self, n, mode, bounds):
|
def _sample_range(self, n, mode, bounds):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
dim = bounds.shape[0]
|
||||||
|
if mode in ['chebyshev', 'grid'] and dim != 1:
|
||||||
|
raise RuntimeError('Something wrong in Span...')
|
||||||
|
|
||||||
if mode == 'random':
|
if mode == 'random':
|
||||||
pts = torch.rand(size=(n, 1))
|
pts = torch.rand(size=(n, dim))
|
||||||
elif mode == 'chebyshev':
|
elif mode == 'chebyshev':
|
||||||
pts = chebyshev_roots(n).mul(.5).add(.5).reshape(-1, 1)
|
pts = chebyshev_roots(n).mul(.5).add(.5).reshape(-1, 1)
|
||||||
elif mode == 'grid':
|
elif mode == 'grid':
|
||||||
pts = torch.linspace(0, 1, n).reshape(-1, 1)
|
pts = torch.linspace(0, 1, n).reshape(-1, 1)
|
||||||
elif mode == 'lh' or mode == 'latin':
|
elif mode == 'lh' or mode == 'latin':
|
||||||
from scipy.stats import qmc
|
from scipy.stats import qmc
|
||||||
sampler = qmc.LatinHypercube(d=1)
|
sampler = qmc.LatinHypercube(d=dim)
|
||||||
pts = sampler.random(n)
|
pts = sampler.random(n)
|
||||||
pts = torch.from_numpy(pts)
|
pts = torch.from_numpy(pts)
|
||||||
|
|
||||||
pts *= bounds[1] - bounds[0]
|
pts *= bounds[:, 1] - bounds[:, 0]
|
||||||
pts += bounds[0]
|
pts += bounds[:, 0]
|
||||||
|
|
||||||
return pts
|
return pts
|
||||||
|
|
||||||
def sample(self, n, mode='random', variables='all'):
|
def sample(self, n, mode='random', variables='all'):
|
||||||
|
"""TODO
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _1d_sampler(n, mode, variables):
|
||||||
|
""" Sample independentely the variables and cross the results"""
|
||||||
|
tmp = []
|
||||||
|
for variable in variables:
|
||||||
|
if variable in self.range_.keys():
|
||||||
|
bound = torch.tensor([self.range_[variable]])
|
||||||
|
pts_variable = self._sample_range(n, mode, bound)
|
||||||
|
pts_variable = pts_variable.as_subclass(LabelTensor)
|
||||||
|
pts_variable.labels = [variable]
|
||||||
|
|
||||||
|
tmp.append(pts_variable)
|
||||||
|
|
||||||
|
result = tmp[0]
|
||||||
|
for i in tmp[1:]:
|
||||||
|
result = result.append(i, mode='cross')
|
||||||
|
|
||||||
|
for variable in variables:
|
||||||
|
if variable in self.fixed_.keys():
|
||||||
|
value = self.fixed_[variable]
|
||||||
|
pts_variable = torch.tensor([[value]]).repeat(
|
||||||
|
result.shape[0], 1)
|
||||||
|
pts_variable = pts_variable.as_subclass(LabelTensor)
|
||||||
|
pts_variable.labels = [variable]
|
||||||
|
|
||||||
|
result = result.append(pts_variable, mode='std')
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _Nd_sampler(n, mode, variables):
|
||||||
|
""" Sample ll the variables together """
|
||||||
|
bounds = torch.tensor(
|
||||||
|
[v for k, v in self.range_.items() if k in variables]
|
||||||
|
)
|
||||||
|
result = self._sample_range(n, mode, bounds)
|
||||||
|
result = result.as_subclass(LabelTensor)
|
||||||
|
result.labels = list(self.range_.keys())
|
||||||
|
|
||||||
|
for variable in variables:
|
||||||
|
if variable in self.fixed_.keys():
|
||||||
|
value = self.fixed_[variable]
|
||||||
|
pts_variable = torch.tensor([[value]]).repeat(
|
||||||
|
result.shape[0], 1)
|
||||||
|
pts_variable = pts_variable.as_subclass(LabelTensor)
|
||||||
|
pts_variable.labels = [variable]
|
||||||
|
|
||||||
|
result = result.append(pts_variable, mode='std')
|
||||||
|
return result
|
||||||
|
|
||||||
if variables == 'all':
|
if variables == 'all':
|
||||||
variables = list(self.range_.keys()) + list(self.fixed_.keys())
|
variables = list(self.range_.keys()) + list(self.fixed_.keys())
|
||||||
|
|
||||||
result = None
|
if mode in ['grid', 'chebyshev']:
|
||||||
for variable in variables:
|
return _1d_sampler(n, mode, variables)
|
||||||
if variable in self.range_.keys():
|
elif mode in ['random', 'lhs']:
|
||||||
bound = torch.tensor(self.range_[variable])
|
return _Nd_sampler(n, mode, variables)
|
||||||
pts_variable = self._sample_range(n, mode, bound)
|
|
||||||
pts_variable = LabelTensor(pts_variable, [variable])
|
|
||||||
|
|
||||||
elif variable in self.fixed_.keys():
|
|
||||||
value = self.fixed_[variable]
|
|
||||||
pts_variable = LabelTensor(torch.ones(n, 1)*value, [variable])
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
result = pts_variable
|
|
||||||
else:
|
else:
|
||||||
intersect = 'std' if mode == 'random' else 'cross'
|
raise ValueError(f'mode={mode} is not valid.')
|
||||||
result = result.append(pts_variable, intersect)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|||||||
61
tests/test_pinn.py
Normal file
61
tests/test_pinn.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pina import LabelTensor, Condition, Span, PINN
|
||||||
|
from pina.problem import SpatialProblem
|
||||||
|
from pina.model import FeedForward
|
||||||
|
from pina.operators import nabla
|
||||||
|
|
||||||
|
|
||||||
|
class Poisson(SpatialProblem):
|
||||||
|
output_variables = ['u']
|
||||||
|
spatial_domain = Span({'x': [0, 1], 'y': [0, 1]})
|
||||||
|
|
||||||
|
def laplace_equation(input_, output_):
|
||||||
|
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
|
||||||
|
torch.sin(input_.extract(['y'])*torch.pi))
|
||||||
|
nabla_u = nabla(output_, input_, components=['u'], d=['x', 'y'])
|
||||||
|
return nabla_u - force_term
|
||||||
|
|
||||||
|
def nil_dirichlet(input_, output_):
|
||||||
|
value = 0.0
|
||||||
|
return output_.extract(['u']) - value
|
||||||
|
|
||||||
|
conditions = {
|
||||||
|
'gamma1': Condition(Span({'x': [0, 1], 'y': 1}), nil_dirichlet),
|
||||||
|
'gamma2': Condition(Span({'x': [0, 1], 'y': 0}), nil_dirichlet),
|
||||||
|
'gamma3': Condition(Span({'x': 1, 'y': [0, 1]}), nil_dirichlet),
|
||||||
|
'gamma4': Condition(Span({'x': 0, 'y': [0, 1]}), nil_dirichlet),
|
||||||
|
'D': Condition(Span({'x': [0, 1], 'y': [0, 1]}), laplace_equation),
|
||||||
|
}
|
||||||
|
|
||||||
|
def poisson_sol(self, pts):
|
||||||
|
return -(
|
||||||
|
torch.sin(pts.extract(['x'])*torch.pi)*
|
||||||
|
torch.sin(pts.extract(['y'])*torch.pi)
|
||||||
|
)/(2*torch.pi**2)
|
||||||
|
|
||||||
|
truth_solution = poisson_sol
|
||||||
|
|
||||||
|
problem = Poisson()
|
||||||
|
model = FeedForward(2, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_constructor():
|
||||||
|
PINN(problem, model)
|
||||||
|
|
||||||
|
def test_span_pts():
|
||||||
|
pinn = PINN(problem, model)
|
||||||
|
n = 10
|
||||||
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
pinn.span_pts(n, 'grid', boundaries)
|
||||||
|
for b in boundaries:
|
||||||
|
assert pinn.input_pts[b].shape[0] == n
|
||||||
|
pinn.span_pts(n, 'random', boundaries)
|
||||||
|
for b in boundaries:
|
||||||
|
assert pinn.input_pts[b].shape[0] == n
|
||||||
|
|
||||||
|
pinn.span_pts(n, 'grid', locations=['D'])
|
||||||
|
assert pinn.input_pts['D'].shape[0] == n**2
|
||||||
|
pinn.span_pts(n, 'random', locations=['D'])
|
||||||
|
assert pinn.input_pts['D'].shape[0] == n
|
||||||
Reference in New Issue
Block a user