From aef134cfb36445fb5c2ffb0be4b55913bc1fd63b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 21 May 2024 10:34:11 +0200 Subject: [PATCH] :art: Format Python code with psf/black (#301) Co-authored-by: ndem0 --- pina/callbacks/adaptive_refinment_callbacks.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pina/callbacks/adaptive_refinment_callbacks.py b/pina/callbacks/adaptive_refinment_callbacks.py index cb807ac..5af2cc8 100644 --- a/pina/callbacks/adaptive_refinment_callbacks.py +++ b/pina/callbacks/adaptive_refinment_callbacks.py @@ -95,7 +95,7 @@ class R3Refinement(Callback): # average loss avg = (tot_loss.mean()).to("cpu") - old_pts = {} # points to be retained + old_pts = {} # points to be retained for location in self._sampling_locations: pts = trainer._model.problem.input_pts[location] labels = pts.labels @@ -106,18 +106,18 @@ class R3Refinement(Callback): pts = (pts[mask]).as_subclass(LabelTensor) pts.labels = labels old_pts[location] = pts - numb_pts = self._const_pts[location] - len(old_pts[location]) + numb_pts = self._const_pts[location] - len(old_pts[location]) # sample new points trainer._model.problem.discretise_domain( numb_pts, "random", locations=[location] - ) + ) - else: # if no res greater than average, samples all uniformly + else: # if no res greater than average, samples all uniformly numb_pts = self._const_pts[location] # sample new points trainer._model.problem.discretise_domain( numb_pts, "random", locations=[location] - ) + ) # adding previous population points trainer._model.problem.add_points(old_pts) @@ -154,8 +154,6 @@ class R3Refinement(Callback): const_pts[location] = len(pts) self._const_pts = const_pts - - def on_train_epoch_end(self, trainer, __): """ Callback function called at the end of each training epoch.