fix bug
This commit is contained in:
committed by
Nicola Demo
parent
f57a08b875
commit
6ff7c6af5b
@@ -197,15 +197,15 @@ class AbstractProblem(metaclass=ABCMeta):
|
|||||||
pts = merge_tensors(samples)
|
pts = merge_tensors(samples)
|
||||||
self.input_pts[location] = pts
|
self.input_pts[location] = pts
|
||||||
|
|
||||||
if device:
|
|
||||||
self.input_pts[location] = self.input_pts[location].to(device=device) #TODO better fix
|
|
||||||
|
|
||||||
# setting the grad
|
|
||||||
self.input_pts[location].requires_grad_(True)
|
|
||||||
self.input_pts[location].retain_grad()
|
|
||||||
# the condition is sampled if input_pts contains all labels
|
# the condition is sampled if input_pts contains all labels
|
||||||
if sorted(self.input_pts[location].labels) == sorted(self.input_variables):
|
if sorted(self.input_pts[location].labels) == sorted(self.input_variables):
|
||||||
self._have_sampled_points[location] = True
|
self._have_sampled_points[location] = True
|
||||||
|
# setting device
|
||||||
|
if device:
|
||||||
|
self.input_pts[location] = self.input_pts[location].to(device=device) #TODO better fix
|
||||||
|
# setting the grad
|
||||||
|
self.input_pts[location].requires_grad_(True)
|
||||||
|
self.input_pts[location].retain_grad()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def have_sampled_points(self):
|
def have_sampled_points(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user