device for sample points in absProblem (#132)
* device for sample points in absProblem
This commit is contained in:
@@ -111,7 +111,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
continue
|
continue
|
||||||
self.input_pts[condition_name] = samples
|
self.input_pts[condition_name] = samples
|
||||||
|
|
||||||
def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all'):
|
def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all', device=None):
|
||||||
"""
|
"""
|
||||||
Generate a set of points to span the `Location` of all the conditions of
|
Generate a set of points to span the `Location` of all the conditions of
|
||||||
the problem.
|
the problem.
|
||||||
@@ -196,6 +196,10 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
] + already_sampled
|
] + already_sampled
|
||||||
pts = merge_tensors(samples)
|
pts = merge_tensors(samples)
|
||||||
self.input_pts[location] = pts
|
self.input_pts[location] = pts
|
||||||
|
|
||||||
|
if device:
|
||||||
|
self.input_pts[location] = self.input_pts[location].to(device=device) #TODO better fix
|
||||||
|
|
||||||
# setting the grad
|
# setting the grad
|
||||||
self.input_pts[location].requires_grad_(True)
|
self.input_pts[location].requires_grad_(True)
|
||||||
self.input_pts[location].retain_grad()
|
self.input_pts[location].retain_grad()
|
||||||
|
|||||||
@@ -105,8 +105,18 @@ def test_train_extra_feats_cpu():
|
|||||||
pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
|
pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
|
||||||
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
|
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
def test_train_gpu(): #TODO fix ASAP
|
||||||
|
poisson_problem = Poisson()
|
||||||
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
|
n = 10
|
||||||
|
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||||
|
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
|
||||||
|
poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu
|
||||||
|
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
||||||
|
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
def test_train_2():
|
def test_train_2():
|
||||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||||
n = 10
|
n = 10
|
||||||
|
|||||||
Reference in New Issue
Block a user