diff --git a/pina/callbacks/adaptive_refinment_callbacks.py b/pina/callbacks/adaptive_refinment_callbacks.py index cb807ac..5af2cc8 100644 --- a/pina/callbacks/adaptive_refinment_callbacks.py +++ b/pina/callbacks/adaptive_refinment_callbacks.py @@ -95,7 +95,7 @@ class R3Refinement(Callback): # average loss avg = (tot_loss.mean()).to("cpu") - old_pts = {} # points to be retained + old_pts = {} # points to be retained for location in self._sampling_locations: pts = trainer._model.problem.input_pts[location] labels = pts.labels @@ -106,18 +106,18 @@ class R3Refinement(Callback): pts = (pts[mask]).as_subclass(LabelTensor) pts.labels = labels old_pts[location] = pts - numb_pts = self._const_pts[location] - len(old_pts[location]) + numb_pts = self._const_pts[location] - len(old_pts[location]) # sample new points trainer._model.problem.discretise_domain( numb_pts, "random", locations=[location] - ) + ) - else: # if no res greater than average, samples all uniformly + else: # if no res greater than average, samples all uniformly numb_pts = self._const_pts[location] # sample new points trainer._model.problem.discretise_domain( numb_pts, "random", locations=[location] - ) + ) # adding previous population points trainer._model.problem.add_points(old_pts) @@ -154,8 +154,6 @@ class R3Refinement(Callback): const_pts[location] = len(pts) self._const_pts = const_pts - - def on_train_epoch_end(self, trainer, __): """ Callback function called at the end of each training epoch.