adaptive refinement callback (#299)

Fixed problem of non-constant number of points
This commit is contained in:
Michele Alessi
2024-05-21 09:51:13 +02:00
committed by GitHub
parent a72ce67873
commit 5f89968805
2 changed files with 24 additions and 30 deletions

View File

@@ -38,6 +38,7 @@ class R3Refinement(Callback):
# sample every # sample every
check_consistency(sample_every, int) check_consistency(sample_every, int)
self._sample_every = sample_every self._sample_every = sample_every
self._const_pts = None
def _compute_residual(self, trainer): def _compute_residual(self, trainer):
""" """
@@ -94,40 +95,29 @@ class R3Refinement(Callback):
# average loss # average loss
avg = (tot_loss.mean()).to("cpu") avg = (tot_loss.mean()).to("cpu")
old_pts = {} # points to be retained
# points to keep
old_pts = {}
tot_points = 0
for location in self._sampling_locations: for location in self._sampling_locations:
pts = trainer._model.problem.input_pts[location] pts = trainer._model.problem.input_pts[location]
labels = pts.labels labels = pts.labels
pts = pts.cpu().detach().as_subclass(torch.Tensor) pts = pts.cpu().detach().as_subclass(torch.Tensor)
residuals = res_loss[location].cpu() residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten() mask = (residuals > avg).flatten()
if any( if any(mask): # append residuals greater than average
mask
): # if there are residuals greater than averge we append them
# Fix the issue, masking remove labels
pts = (pts[mask]).as_subclass(LabelTensor) pts = (pts[mask]).as_subclass(LabelTensor)
pts.labels = labels pts.labels = labels
old_pts[location] = pts old_pts[location] = pts
tot_points += len(pts) numb_pts = self._const_pts[location] - len(old_pts[location])
# sample new points
trainer._model.problem.discretise_domain(
numb_pts, "random", locations=[location]
)
# extract new points to sample uniformally for each location else: # if no res greater than average, samples all uniformly
n_points = (self._tot_pop_numb - tot_points) // len( numb_pts = self._const_pts[location]
self._sampling_locations # sample new points
) trainer._model.problem.discretise_domain(
remainder = (self._tot_pop_numb - tot_points) % len( numb_pts, "random", locations=[location]
self._sampling_locations )
)
n_uniform_points = [n_points] * len(self._sampling_locations)
n_uniform_points[-1] += remainder
# sample new points
for numb_pts, loc in zip(n_uniform_points, self._sampling_locations):
trainer._model.problem.discretise_domain(
numb_pts, "random", locations=[loc]
)
# adding previous population points # adding previous population points
trainer._model.problem.add_points(old_pts) trainer._model.problem.add_points(old_pts)
@@ -158,11 +148,13 @@ class R3Refinement(Callback):
self._sampling_locations = locations self._sampling_locations = locations
# extract total population # extract total population
total_population = 0 const_pts = {} # for each location, store the # of pts to keep constant
for location in self._sampling_locations: for location in self._sampling_locations:
pts = trainer._model.problem.input_pts[location] pts = trainer._model.problem.input_pts[location]
total_population += len(pts) const_pts[location] = len(pts)
self._tot_pop_numb = total_population self._const_pts = const_pts
def on_train_epoch_end(self, trainer, __): def on_train_epoch_end(self, trainer, __):
""" """

View File

@@ -75,13 +75,15 @@ def test_r3refinment_routine():
max_epochs=5) max_epochs=5)
trainer.train() trainer.train()
def test_r3refinment_routine_double_precision(): def test_r3refinment_routine():
model = FeedForward(len(poisson_problem.input_variables), model = FeedForward(len(poisson_problem.input_variables),
len(poisson_problem.output_variables)) len(poisson_problem.output_variables))
solver = PINN(problem=poisson_problem, model=model) solver = PINN(problem=poisson_problem, model=model)
trainer = Trainer(solver=solver, trainer = Trainer(solver=solver,
precision='64-true', callbacks=[R3Refinement(sample_every=1)],
accelerator='cpu', accelerator='cpu',
callbacks=[R3Refinement(sample_every=2)],
max_epochs=5) max_epochs=5)
before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
trainer.train() trainer.train()
after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
assert before_n_points == after_n_points