Solvers logging (#202)
* Modifying solvers to log every epoch correctly * add `on_epoch` flag to logger * fix bug in `pinn.py` `pts -> samples` in `_loss_phys` * add `optimizer_zero_grad()` in garom generator training loop * modify imports in `callbacks.py` * fixing tests --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
d654259428
commit
3f9305d475
@@ -61,7 +61,7 @@ class R3Refinement(Callback):
|
||||
pts.retain_grad()
|
||||
# PINN loss: equation evaluated only on locations where sampling is needed
|
||||
target = condition.equation.residual(pts, solver.forward(pts))
|
||||
res_loss[location] = torch.abs(target)
|
||||
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
|
||||
tot_loss.append(torch.abs(target))
|
||||
|
||||
return torch.vstack(tot_loss), res_loss
|
||||
@@ -74,6 +74,7 @@ class R3Refinement(Callback):
|
||||
"""
|
||||
# compute residual (all device possible)
|
||||
tot_loss, res_loss = self._compute_residual(trainer)
|
||||
tot_loss = tot_loss.as_subclass(torch.Tensor)
|
||||
|
||||
# !!!!!! From now everything is performed on CPU !!!!!!
|
||||
|
||||
@@ -89,12 +90,11 @@ class R3Refinement(Callback):
|
||||
pts = pts.cpu().detach()
|
||||
residuals = res_loss[location].cpu()
|
||||
mask = (residuals > avg).flatten()
|
||||
# TODO masking remove labels
|
||||
pts = pts[mask]
|
||||
pts.labels = labels
|
||||
####
|
||||
old_pts[location] = pts
|
||||
tot_points += len(pts)
|
||||
if any(mask): # if there are residuals greater than average we append them
|
||||
pts = pts[mask] # TODO masking remove labels
|
||||
pts.labels = labels
|
||||
old_pts[location] = pts
|
||||
tot_points += len(pts)
|
||||
|
||||
# extract new points to sample uniformly for each location
|
||||
n_points = (self._tot_pop_numb - tot_points ) // len(self._sampling_locations)
|
||||
|
||||
Reference in New Issue
Block a user