device for sample points in absProblem (#132)

* device for sample points in absProblem
This commit is contained in:
Nicola Demo
2023-06-28 15:13:47 +02:00
parent 701046661f
commit f57a08b875
2 changed files with 17 additions and 3 deletions

View File

@@ -105,8 +105,18 @@ def test_train_extra_feats_cpu():
pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
trainer.train()
"""
def test_train_gpu(): #TODO fix ASAP
poisson_problem = Poisson()
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
trainer.train()
def test_train_2():
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
n = 10
@@ -212,4 +222,4 @@ if torch.cuda.is_available():
pinn.discretise_domain(n, 'grid', locations=boundaries)
pinn.discretise_domain(n, 'grid', locations=['D'])
pinn.train(5)
"""
"""