diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index c3abf4b..eae28bf 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -197,15 +197,15 @@ class AbstractProblem(metaclass=ABCMeta): pts = merge_tensors(samples) self.input_pts[location] = pts - if device: - self.input_pts[location] = self.input_pts[location].to(device=device) #TODO better fix - - # setting the grad - self.input_pts[location].requires_grad_(True) - self.input_pts[location].retain_grad() # the condition is sampled if input_pts contains all labels if sorted(self.input_pts[location].labels) == sorted(self.input_variables): self._have_sampled_points[location] = True + # setting device + if device: + self.input_pts[location] = self.input_pts[location].to(device=device) #TODO better fix + # setting the grad + self.input_pts[location].requires_grad_(True) + self.input_pts[location].retain_grad() @property def have_sampled_points(self):