solving span bugs (#57)
This commit is contained in:
17
pina/pinn.py
17
pina/pinn.py
@@ -173,7 +173,16 @@ class PINN(object):
|
|||||||
>>> pinn.span_pts(n=10, mode='grid', location=['bound1'])
|
>>> pinn.span_pts(n=10, mode='grid', location=['bound1'])
|
||||||
>>> pinn.span_pts(n=10, mode='grid', variables=['x'])
|
>>> pinn.span_pts(n=10, mode='grid', variables=['x'])
|
||||||
"""
|
"""
|
||||||
if isinstance(args[0], int) and isinstance(args[1], str):
|
|
||||||
|
if all(key in kwargs for key in ['n', 'mode']):
|
||||||
|
argument = {}
|
||||||
|
argument['n'] = kwargs['n']
|
||||||
|
argument['mode'] = kwargs['mode']
|
||||||
|
argument['variables'] = self.problem.input_variables
|
||||||
|
arguments = [argument]
|
||||||
|
elif any(key in kwargs for key in ['n', 'mode']) and args:
|
||||||
|
raise ValueError("Don't mix args and kwargs")
|
||||||
|
elif isinstance(args[0], int) and isinstance(args[1], str):
|
||||||
argument = {}
|
argument = {}
|
||||||
argument['n'] = int(args[0])
|
argument['n'] = int(args[0])
|
||||||
argument['mode'] = args[1]
|
argument['mode'] = args[1]
|
||||||
@@ -181,12 +190,6 @@ class PINN(object):
|
|||||||
arguments = [argument]
|
arguments = [argument]
|
||||||
elif all(isinstance(arg, dict) for arg in args):
|
elif all(isinstance(arg, dict) for arg in args):
|
||||||
arguments = args
|
arguments = args
|
||||||
elif all(key in kwargs for key in ['n', 'mode']):
|
|
||||||
argument = {}
|
|
||||||
argument['n'] = kwargs['n']
|
|
||||||
argument['mode'] = kwargs['mode']
|
|
||||||
argument['variables'] = self.problem.input_variables
|
|
||||||
arguments = [argument]
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|
||||||
|
|||||||
19
pina/span.py
19
pina/span.py
@@ -100,6 +100,25 @@ class Span(Location):
|
|||||||
result = result.append(pts_variable, mode='std')
|
result = result.append(pts_variable, mode='std')
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _single_points_sample(n, variables):
|
||||||
|
tmp = []
|
||||||
|
for variable in variables:
|
||||||
|
if variable in self.fixed_.keys():
|
||||||
|
value = self.fixed_[variable]
|
||||||
|
pts_variable = torch.tensor([[value]]).repeat(n, 1)
|
||||||
|
pts_variable = pts_variable.as_subclass(LabelTensor)
|
||||||
|
pts_variable.labels = [variable]
|
||||||
|
tmp.append(pts_variable)
|
||||||
|
|
||||||
|
result = tmp[0]
|
||||||
|
for i in tmp[1:]:
|
||||||
|
result = result.append(i, mode='std')
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
if self.fixed_ and (not self.range_):
|
||||||
|
return _single_points_sample(n, variables)
|
||||||
|
|
||||||
if variables == 'all':
|
if variables == 'all':
|
||||||
variables = list(self.range_.keys()) + list(self.fixed_.keys())
|
variables = list(self.range_.keys()) + list(self.fixed_.keys())
|
||||||
|
|
||||||
|
|||||||
@@ -70,6 +70,32 @@ def test_span_pts():
|
|||||||
assert pinn.input_pts['D'].shape[0] == n
|
assert pinn.input_pts['D'].shape[0] == n
|
||||||
|
|
||||||
|
|
||||||
|
def test_sampling_all_args():
|
||||||
|
pinn = PINN(problem, model)
|
||||||
|
n = 10
|
||||||
|
pinn.span_pts(n, 'grid', locations=['D'])
|
||||||
|
|
||||||
|
|
||||||
|
def test_sampling_all_kwargs():
|
||||||
|
pinn = PINN(problem, model)
|
||||||
|
n = 10
|
||||||
|
pinn.span_pts(n=n, mode='latin', locations=['D'])
|
||||||
|
|
||||||
|
|
||||||
|
def test_sampling_dict():
|
||||||
|
pinn = PINN(problem, model)
|
||||||
|
n = 10
|
||||||
|
pinn.span_pts(
|
||||||
|
{'variables': ['x', 'y'], 'mode': 'grid', 'n': n}, locations=['D'])
|
||||||
|
|
||||||
|
|
||||||
|
def test_sampling_mixed_args_kwargs():
|
||||||
|
pinn = PINN(problem, model)
|
||||||
|
n = 10
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
pinn.span_pts(n, mode='latin', locations=['D'])
|
||||||
|
|
||||||
|
|
||||||
def test_train():
|
def test_train():
|
||||||
pinn = PINN(problem, model)
|
pinn = PINN(problem, model)
|
||||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
@@ -130,4 +156,4 @@ if torch.cuda.is_available():
|
|||||||
n = 100
|
n = 100
|
||||||
pinn.span_pts(n, 'grid', boundaries)
|
pinn.span_pts(n, 'grid', boundaries)
|
||||||
pinn.span_pts(n, 'grid', locations=['D'])
|
pinn.span_pts(n, 'grid', locations=['D'])
|
||||||
pinn.train(5)
|
pinn.train(5)
|
||||||
|
|||||||
Reference in New Issue
Block a user