device for sample points in absProblem (#132)

* device for sample points in absProblem
This commit is contained in:
Nicola Demo
2023-06-28 15:13:47 +02:00
parent 701046661f
commit f57a08b875
2 changed files with 17 additions and 3 deletions

View File

@@ -111,7 +111,7 @@ class AbstractProblem(metaclass=ABCMeta):
continue
self.input_pts[condition_name] = samples
def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all'):
def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all', device=None):
"""
Generate a set of points to span the `Location` of all the conditions of
the problem.
@@ -196,6 +196,10 @@ class AbstractProblem(metaclass=ABCMeta):
] + already_sampled
pts = merge_tensors(samples)
self.input_pts[location] = pts
if device:
self.input_pts[location] = self.input_pts[location].to(device=device) #TODO better fix
# setting the grad
self.input_pts[location].requires_grad_(True)
self.input_pts[location].retain_grad()