batch_enhancement (#51)
This commit is contained in:
@@ -31,19 +31,22 @@ class Poisson(SpatialProblem):
|
||||
|
||||
def poisson_sol(self, pts):
|
||||
return -(
|
||||
torch.sin(pts.extract(['x'])*torch.pi)*
|
||||
torch.sin(pts.extract(['x'])*torch.pi) *
|
||||
torch.sin(pts.extract(['y'])*torch.pi)
|
||||
)/(2*torch.pi**2)
|
||||
|
||||
truth_solution = poisson_sol
|
||||
|
||||
|
||||
problem = Poisson()
|
||||
|
||||
model = FeedForward(problem.input_variables, problem.output_variables)
|
||||
|
||||
|
||||
def test_constructor():
|
||||
PINN(problem, model)
|
||||
|
||||
|
||||
def test_span_pts():
|
||||
pinn = PINN(problem, model)
|
||||
n = 10
|
||||
@@ -60,6 +63,7 @@ def test_span_pts():
|
||||
pinn.span_pts(n, 'random', locations=['D'])
|
||||
assert pinn.input_pts['D'].shape[0] == n
|
||||
|
||||
|
||||
def test_train():
|
||||
pinn = PINN(problem, model)
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
@@ -68,6 +72,7 @@ def test_train():
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.train(5)
|
||||
|
||||
|
||||
def test_train():
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
n = 10
|
||||
@@ -78,4 +83,45 @@ def test_train():
|
||||
pinn.span_pts(n, 'grid', boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.train(50, save_loss=i)
|
||||
assert list(pinn.history_loss.keys()) == truth_key
|
||||
assert list(pinn.history_loss.keys()) == truth_key
|
||||
|
||||
|
||||
def test_train_batch():
|
||||
pinn = PINN(problem, model, batch_size=6)
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
n = 10
|
||||
pinn.span_pts(n, 'grid', boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.train(5)
|
||||
|
||||
|
||||
def test_train_batch():
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
n = 10
|
||||
expected_keys = [[], list(range(0, 50, 3))]
|
||||
param = [0, 3]
|
||||
for i, truth_key in zip(param, expected_keys):
|
||||
pinn = PINN(problem, model, batch_size=6)
|
||||
pinn.span_pts(n, 'grid', boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.train(50, save_loss=i)
|
||||
assert list(pinn.history_loss.keys()) == truth_key
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
|
||||
def test_gpu_train():
|
||||
pinn = PINN(problem, model, batch_size=20, device='cuda')
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
n = 100
|
||||
pinn.span_pts(n, 'grid', boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.train(5)
|
||||
|
||||
def test_gpu_train_nobatch():
|
||||
pinn = PINN(problem, model, batch_size=None, device='cuda')
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
n = 100
|
||||
pinn.span_pts(n, 'grid', boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.train(5)
|
||||
|
||||
Reference in New Issue
Block a user