solving span bugs (#57)
This commit is contained in:
@@ -70,6 +70,32 @@ def test_span_pts():
|
||||
assert pinn.input_pts['D'].shape[0] == n
|
||||
|
||||
|
||||
def test_sampling_all_args():
|
||||
pinn = PINN(problem, model)
|
||||
n = 10
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
|
||||
|
||||
def test_sampling_all_kwargs():
|
||||
pinn = PINN(problem, model)
|
||||
n = 10
|
||||
pinn.span_pts(n=n, mode='latin', locations=['D'])
|
||||
|
||||
|
||||
def test_sampling_dict():
|
||||
pinn = PINN(problem, model)
|
||||
n = 10
|
||||
pinn.span_pts(
|
||||
{'variables': ['x', 'y'], 'mode': 'grid', 'n': n}, locations=['D'])
|
||||
|
||||
|
||||
def test_sampling_mixed_args_kwargs():
|
||||
pinn = PINN(problem, model)
|
||||
n = 10
|
||||
with pytest.raises(ValueError):
|
||||
pinn.span_pts(n, mode='latin', locations=['D'])
|
||||
|
||||
|
||||
def test_train():
|
||||
pinn = PINN(problem, model)
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
@@ -130,4 +156,4 @@ if torch.cuda.is_available():
|
||||
n = 100
|
||||
pinn.span_pts(n, 'grid', boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.train(5)
|
||||
pinn.train(5)
|
||||
|
||||
Reference in New Issue
Block a user