Solvers logging (#202)

* Modifying solvers to log every epoch correctly
* add `on_epoch` flag to logger
* fix bug in `pinn.py` `pts -> samples` in `_loss_phys`
* add `optimizer_zero_grad()` in garom generator training loop
* modify imports in `callbacks.py`
* fixing tests

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-11-08 14:10:23 +01:00
committed by Nicola Demo
parent d654259428
commit 3f9305d475
5 changed files with 28 additions and 27 deletions

View File

@@ -61,7 +61,7 @@ class R3Refinement(Callback):
pts.retain_grad()
# PINN loss: equation evaluated only on locations where sampling is needed
target = condition.equation.residual(pts, solver.forward(pts))
res_loss[location] = torch.abs(target)
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target))
return torch.vstack(tot_loss), res_loss
@@ -74,6 +74,7 @@ class R3Refinement(Callback):
"""
# compute residual (all device possible)
tot_loss, res_loss = self._compute_residual(trainer)
tot_loss = tot_loss.as_subclass(torch.Tensor)
# !!!!!! From now everything is performed on CPU !!!!!!
@@ -89,12 +90,11 @@ class R3Refinement(Callback):
pts = pts.cpu().detach()
residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten()
# TODO masking remove labels
pts = pts[mask]
pts.labels = labels
####
old_pts[location] = pts
tot_points += len(pts)
if any(mask): # if there are residuals greater than averge we append them
pts = pts[mask] # TODO masking remove labels
pts.labels = labels
old_pts[location] = pts
tot_points += len(pts)
# extract new points to sample uniformally for each location
n_points = (self._tot_pop_numb - tot_points ) // len(self._sampling_locations)