batch_enhancement (#51)

This commit is contained in:
Dario Coscia
2022-12-12 11:09:20 +01:00
committed by GitHub
parent d70f5e730a
commit dbd78c9cf3
4 changed files with 236 additions and 59 deletions

View File

@@ -31,19 +31,22 @@ class Poisson(SpatialProblem):
def poisson_sol(self, pts):
return -(
torch.sin(pts.extract(['x'])*torch.pi)*
torch.sin(pts.extract(['x'])*torch.pi) *
torch.sin(pts.extract(['y'])*torch.pi)
)/(2*torch.pi**2)
truth_solution = poisson_sol
problem = Poisson()
model = FeedForward(problem.input_variables, problem.output_variables)
def test_constructor():
PINN(problem, model)
def test_span_pts():
pinn = PINN(problem, model)
n = 10
@@ -60,6 +63,7 @@ def test_span_pts():
pinn.span_pts(n, 'random', locations=['D'])
assert pinn.input_pts['D'].shape[0] == n
def test_train():
pinn = PINN(problem, model)
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
@@ -68,6 +72,7 @@ def test_train():
pinn.span_pts(n, 'grid', locations=['D'])
pinn.train(5)
def test_train():
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
@@ -78,4 +83,45 @@ def test_train():
pinn.span_pts(n, 'grid', boundaries)
pinn.span_pts(n, 'grid', locations=['D'])
pinn.train(50, save_loss=i)
assert list(pinn.history_loss.keys()) == truth_key
assert list(pinn.history_loss.keys()) == truth_key
def test_train_batch():
pinn = PINN(problem, model, batch_size=6)
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
pinn.span_pts(n, 'grid', boundaries)
pinn.span_pts(n, 'grid', locations=['D'])
pinn.train(5)
def test_train_batch():
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
expected_keys = [[], list(range(0, 50, 3))]
param = [0, 3]
for i, truth_key in zip(param, expected_keys):
pinn = PINN(problem, model, batch_size=6)
pinn.span_pts(n, 'grid', boundaries)
pinn.span_pts(n, 'grid', locations=['D'])
pinn.train(50, save_loss=i)
assert list(pinn.history_loss.keys()) == truth_key
if torch.cuda.is_available():
def test_gpu_train():
pinn = PINN(problem, model, batch_size=20, device='cuda')
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 100
pinn.span_pts(n, 'grid', boundaries)
pinn.span_pts(n, 'grid', locations=['D'])
pinn.train(5)
def test_gpu_train_nobatch():
pinn = PINN(problem, model, batch_size=None, device='cuda')
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 100
pinn.span_pts(n, 'grid', boundaries)
pinn.span_pts(n, 'grid', locations=['D'])
pinn.train(5)