🎨 Format Python code with psf/black (#301)
Co-authored-by: ndem0 <ndem0@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
5f89968805
commit
aef134cfb3
@@ -95,7 +95,7 @@ class R3Refinement(Callback):
|
|||||||
|
|
||||||
# average loss
|
# average loss
|
||||||
avg = (tot_loss.mean()).to("cpu")
|
avg = (tot_loss.mean()).to("cpu")
|
||||||
old_pts = {} # points to be retained
|
old_pts = {} # points to be retained
|
||||||
for location in self._sampling_locations:
|
for location in self._sampling_locations:
|
||||||
pts = trainer._model.problem.input_pts[location]
|
pts = trainer._model.problem.input_pts[location]
|
||||||
labels = pts.labels
|
labels = pts.labels
|
||||||
@@ -106,18 +106,18 @@ class R3Refinement(Callback):
|
|||||||
pts = (pts[mask]).as_subclass(LabelTensor)
|
pts = (pts[mask]).as_subclass(LabelTensor)
|
||||||
pts.labels = labels
|
pts.labels = labels
|
||||||
old_pts[location] = pts
|
old_pts[location] = pts
|
||||||
numb_pts = self._const_pts[location] - len(old_pts[location])
|
numb_pts = self._const_pts[location] - len(old_pts[location])
|
||||||
# sample new points
|
# sample new points
|
||||||
trainer._model.problem.discretise_domain(
|
trainer._model.problem.discretise_domain(
|
||||||
numb_pts, "random", locations=[location]
|
numb_pts, "random", locations=[location]
|
||||||
)
|
)
|
||||||
|
|
||||||
else: # if no res greater than average, samples all uniformly
|
else: # if no res greater than average, samples all uniformly
|
||||||
numb_pts = self._const_pts[location]
|
numb_pts = self._const_pts[location]
|
||||||
# sample new points
|
# sample new points
|
||||||
trainer._model.problem.discretise_domain(
|
trainer._model.problem.discretise_domain(
|
||||||
numb_pts, "random", locations=[location]
|
numb_pts, "random", locations=[location]
|
||||||
)
|
)
|
||||||
# adding previous population points
|
# adding previous population points
|
||||||
trainer._model.problem.add_points(old_pts)
|
trainer._model.problem.add_points(old_pts)
|
||||||
|
|
||||||
@@ -154,8 +154,6 @@ class R3Refinement(Callback):
|
|||||||
const_pts[location] = len(pts)
|
const_pts[location] = len(pts)
|
||||||
self._const_pts = const_pts
|
self._const_pts = const_pts
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, __):
|
def on_train_epoch_end(self, trainer, __):
|
||||||
"""
|
"""
|
||||||
Callback function called at the end of each training epoch.
|
Callback function called at the end of each training epoch.
|
||||||
|
|||||||
Reference in New Issue
Block a user