This commit is contained in:
Dario Coscia
2023-06-29 09:30:50 +02:00
committed by Nicola Demo
parent f57a08b875
commit 6ff7c6af5b

View File

@@ -197,15 +197,15 @@ class AbstractProblem(metaclass=ABCMeta):
pts = merge_tensors(samples)
self.input_pts[location] = pts
if device:
self.input_pts[location] = self.input_pts[location].to(device=device) #TODO better fix
# setting the grad
self.input_pts[location].requires_grad_(True)
self.input_pts[location].retain_grad()
# the condition is sampled if input_pts contains all labels
if sorted(self.input_pts[location].labels) == sorted(self.input_variables):
self._have_sampled_points[location] = True
# setting device
if device:
self.input_pts[location] = self.input_pts[location].to(device=device) #TODO better fix
# setting the grad
self.input_pts[location].requires_grad_(True)
self.input_pts[location].retain_grad()
@property
def have_sampled_points(self):