From 5f89968805bbf24d830797ef2a7149a36944d7da Mon Sep 17 00:00:00 2001 From: Michele Alessi <108744550+alessimichele@users.noreply.github.com> Date: Tue, 21 May 2024 09:51:13 +0200 Subject: [PATCH] adaptive refinement callback (#299) Fixed problem of non-constant number of points --- .../callbacks/adaptive_refinment_callbacks.py | 46 ++++++++----------- .../test_adaptive_refinment_callbacks.py | 8 ++-- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/pina/callbacks/adaptive_refinment_callbacks.py b/pina/callbacks/adaptive_refinment_callbacks.py index ec9984d..cb807ac 100644 --- a/pina/callbacks/adaptive_refinment_callbacks.py +++ b/pina/callbacks/adaptive_refinment_callbacks.py @@ -38,6 +38,7 @@ class R3Refinement(Callback): # sample every check_consistency(sample_every, int) self._sample_every = sample_every + self._const_pts = None def _compute_residual(self, trainer): """ @@ -94,40 +95,29 @@ class R3Refinement(Callback): # average loss avg = (tot_loss.mean()).to("cpu") - - # points to keep - old_pts = {} - tot_points = 0 + old_pts = {} # points to be retained for location in self._sampling_locations: pts = trainer._model.problem.input_pts[location] labels = pts.labels pts = pts.cpu().detach().as_subclass(torch.Tensor) residuals = res_loss[location].cpu() mask = (residuals > avg).flatten() - if any( - mask - ): # if there are residuals greater than averge we append them - # Fix the issue, masking remove labels + if any(mask): # append residuals greater than average pts = (pts[mask]).as_subclass(LabelTensor) pts.labels = labels old_pts[location] = pts - tot_points += len(pts) + numb_pts = self._const_pts[location] - len(old_pts[location]) + # sample new points + trainer._model.problem.discretise_domain( + numb_pts, "random", locations=[location] + ) - # extract new points to sample uniformally for each location - n_points = (self._tot_pop_numb - tot_points) // len( - self._sampling_locations - ) - remainder = (self._tot_pop_numb - tot_points) % len( - self._sampling_locations - ) - n_uniform_points = [n_points] * len(self._sampling_locations) - n_uniform_points[-1] += remainder - - # sample new points - for numb_pts, loc in zip(n_uniform_points, self._sampling_locations): - trainer._model.problem.discretise_domain( - numb_pts, "random", locations=[loc] - ) + else: # if no res greater than average, samples all uniformly + numb_pts = self._const_pts[location] + # sample new points + trainer._model.problem.discretise_domain( + numb_pts, "random", locations=[location] + ) # adding previous population points trainer._model.problem.add_points(old_pts) @@ -158,11 +148,13 @@ class R3Refinement(Callback): self._sampling_locations = locations # extract total population - total_population = 0 + const_pts = {} # for each location, store the # of pts to keep constant for location in self._sampling_locations: pts = trainer._model.problem.input_pts[location] - total_population += len(pts) - self._tot_pop_numb = total_population + const_pts[location] = len(pts) + self._const_pts = const_pts + + def on_train_epoch_end(self, trainer, __): """ diff --git a/tests/test_callbacks/test_adaptive_refinment_callbacks.py b/tests/test_callbacks/test_adaptive_refinment_callbacks.py index 214257d..e5c46a1 100644 --- a/tests/test_callbacks/test_adaptive_refinment_callbacks.py +++ b/tests/test_callbacks/test_adaptive_refinment_callbacks.py @@ -75,13 +75,15 @@ def test_r3refinment_routine(): max_epochs=5) trainer.train() -def test_r3refinment_routine_double_precision(): +def test_r3refinment_routine(): model = FeedForward(len(poisson_problem.input_variables), len(poisson_problem.output_variables)) solver = PINN(problem=poisson_problem, model=model) trainer = Trainer(solver=solver, - precision='64-true', + callbacks=[R3Refinement(sample_every=1)], accelerator='cpu', - callbacks=[R3Refinement(sample_every=2)], max_epochs=5) + before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()} trainer.train() + after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()} + assert before_n_points == after_n_points