adaptive refinement callback (#299)
Fixed problem of non-constant number of points
This commit is contained in:
@@ -38,6 +38,7 @@ class R3Refinement(Callback):
|
||||
# sample every
|
||||
check_consistency(sample_every, int)
|
||||
self._sample_every = sample_every
|
||||
self._const_pts = None
|
||||
|
||||
def _compute_residual(self, trainer):
|
||||
"""
|
||||
@@ -94,40 +95,29 @@ class R3Refinement(Callback):
|
||||
|
||||
# average loss
|
||||
avg = (tot_loss.mean()).to("cpu")
|
||||
|
||||
# points to keep
|
||||
old_pts = {}
|
||||
tot_points = 0
|
||||
old_pts = {} # points to be retained
|
||||
for location in self._sampling_locations:
|
||||
pts = trainer._model.problem.input_pts[location]
|
||||
labels = pts.labels
|
||||
pts = pts.cpu().detach().as_subclass(torch.Tensor)
|
||||
residuals = res_loss[location].cpu()
|
||||
mask = (residuals > avg).flatten()
|
||||
if any(
|
||||
mask
|
||||
): # if there are residuals greater than averge we append them
|
||||
# Fix the issue, masking remove labels
|
||||
if any(mask): # append residuals greater than average
|
||||
pts = (pts[mask]).as_subclass(LabelTensor)
|
||||
pts.labels = labels
|
||||
old_pts[location] = pts
|
||||
tot_points += len(pts)
|
||||
numb_pts = self._const_pts[location] - len(old_pts[location])
|
||||
# sample new points
|
||||
trainer._model.problem.discretise_domain(
|
||||
numb_pts, "random", locations=[location]
|
||||
)
|
||||
|
||||
# extract new points to sample uniformally for each location
|
||||
n_points = (self._tot_pop_numb - tot_points) // len(
|
||||
self._sampling_locations
|
||||
)
|
||||
remainder = (self._tot_pop_numb - tot_points) % len(
|
||||
self._sampling_locations
|
||||
)
|
||||
n_uniform_points = [n_points] * len(self._sampling_locations)
|
||||
n_uniform_points[-1] += remainder
|
||||
|
||||
# sample new points
|
||||
for numb_pts, loc in zip(n_uniform_points, self._sampling_locations):
|
||||
trainer._model.problem.discretise_domain(
|
||||
numb_pts, "random", locations=[loc]
|
||||
)
|
||||
else: # if no res greater than average, samples all uniformly
|
||||
numb_pts = self._const_pts[location]
|
||||
# sample new points
|
||||
trainer._model.problem.discretise_domain(
|
||||
numb_pts, "random", locations=[location]
|
||||
)
|
||||
# adding previous population points
|
||||
trainer._model.problem.add_points(old_pts)
|
||||
|
||||
@@ -158,11 +148,13 @@ class R3Refinement(Callback):
|
||||
self._sampling_locations = locations
|
||||
|
||||
# extract total population
|
||||
total_population = 0
|
||||
const_pts = {} # for each location, store the # of pts to keep constant
|
||||
for location in self._sampling_locations:
|
||||
pts = trainer._model.problem.input_pts[location]
|
||||
total_population += len(pts)
|
||||
self._tot_pop_numb = total_population
|
||||
const_pts[location] = len(pts)
|
||||
self._const_pts = const_pts
|
||||
|
||||
|
||||
|
||||
def on_train_epoch_end(self, trainer, __):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user