🎨 Format Python code with psf/black (#301)

Co-authored-by: ndem0 <ndem0@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2024-05-21 10:34:11 +02:00
committed by GitHub
parent 5f89968805
commit aef134cfb3

View File

@@ -95,7 +95,7 @@ class R3Refinement(Callback):
        # average loss
        avg = (tot_loss.mean()).to("cpu")
        old_pts = {}  # points to be retained
        for location in self._sampling_locations:
            pts = trainer._model.problem.input_pts[location]
            labels = pts.labels
@@ -106,18 +106,18 @@ class R3Refinement(Callback):
                pts = (pts[mask]).as_subclass(LabelTensor)
                pts.labels = labels
                old_pts[location] = pts
                numb_pts = self._const_pts[location] - len(old_pts[location])
                # sample new points
                trainer._model.problem.discretise_domain(
                    numb_pts, "random", locations=[location]
                )
            else:  # if no res greater than average, samples all uniformly
                numb_pts = self._const_pts[location]
                # sample new points
                trainer._model.problem.discretise_domain(
                    numb_pts, "random", locations=[location]
                )
        # adding previous population points
        trainer._model.problem.add_points(old_pts)
@@ -154,8 +154,6 @@ class R3Refinement(Callback):
            const_pts[location] = len(pts)
        self._const_pts = const_pts

    def on_train_epoch_end(self, trainer, __):
        """
        Callback function called at the end of each training epoch.