Solvers logging (#202)

* modify solvers to log every epoch correctly
* add an `on_epoch` flag to the logger calls (see the sketch below)
* fix a bug in `pinn.py`: `pts -> samples` in `_loss_phys`
* add `optimizer.zero_grad()` in the GAROM generator training loop
* modify imports in `callbacks.py`
* fix tests
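For context, the logging fix leans on PyTorch Lightning's built-in aggregation: a value logged with `on_epoch=True, on_step=False` is accumulated over the epoch and written out once as an epoch-level mean. A minimal sketch of the pattern (the module below is illustrative, not PINA code):

```python
import torch
import pytorch_lightning as pl

class SketchModule(pl.LightningModule):
    """Toy module showing the per-epoch logging pattern."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # on_epoch=True, on_step=False: Lightning accumulates the value
        # across batches and logs a single epoch-level mean.
        self.log('mean_loss', float(loss), prog_bar=True, logger=True,
                 on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

Passed to a `Trainer`, this logs one `mean_loss` point per epoch instead of one per optimizer step.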

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
Authored by Dario Coscia on 2023-11-08 14:10:23 +01:00; committed by Nicola Demo
parent d654259428
commit 3f9305d475
5 changed files with 28 additions and 27 deletions


@@ -61,7 +61,7 @@ class R3Refinement(Callback):
             pts.retain_grad()
             # PINN loss: equation evaluated only on locations where sampling is needed
             target = condition.equation.residual(pts, solver.forward(pts))
-            res_loss[location] = torch.abs(target)
+            res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
             tot_loss.append(torch.abs(target))
         return torch.vstack(tot_loss), res_loss
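Note on the change above: the residuals are cast with `Tensor.as_subclass(torch.Tensor)` so that downstream bookkeeping sees a plain tensor rather than PINA's `LabelTensor`. A minimal sketch of what `as_subclass` does (the `Labeled` class is a toy stand-in):

```python
import torch

class Labeled(torch.Tensor):
    """Toy tensor subclass standing in for a labelled tensor type."""
    pass

t = torch.randn(3).as_subclass(Labeled)
print(type(t).__name__)                  # Labeled

# as_subclass reinterprets the same storage under another class;
# casting down to torch.Tensor strips the subclass without copying.
plain = t.as_subclass(torch.Tensor)
print(type(plain).__name__)              # Tensor
print(plain.data_ptr() == t.data_ptr())  # True: no copy is made
```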
@@ -74,6 +74,7 @@ class R3Refinement(Callback):
         """
         # compute residual (all device possible)
         tot_loss, res_loss = self._compute_residual(trainer)
+        tot_loss = tot_loss.as_subclass(torch.Tensor)
         # !!!!!! From now everything is performed on CPU !!!!!!
@@ -89,12 +90,11 @@ class R3Refinement(Callback):
             pts = pts.cpu().detach()
             residuals = res_loss[location].cpu()
             mask = (residuals > avg).flatten()
-            # TODO masking remove labels
-            pts = pts[mask]
-            pts.labels = labels
-            ####
-            old_pts[location] = pts
-            tot_points += len(pts)
+            if any(mask): # if there are residuals greater than averge we append them
+                pts = pts[mask] # TODO masking remove labels
+                pts.labels = labels
+                old_pts[location] = pts
+                tot_points += len(pts)
         # extract new points to sample uniformally for each location
         n_points = (self._tot_pop_numb - tot_points ) // len(self._sampling_locations)
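Two fixes in this hunk: the update is now guarded by `any(mask)` so a location with no above-average residuals keeps its previous points, and `pts.labels = labels` is reassigned because boolean indexing returns a fresh tensor that does not carry ad-hoc attributes along (the `# TODO masking remove labels` note). A small sketch of that attribute loss (the `labels` attribute here is illustrative):

```python
import torch

labels = ['x', 'y']
pts = torch.randn(5, 2)
pts.labels = labels            # ad-hoc attribute attached to the tensor

mask = pts[:, 0] > 0
if any(mask):                  # skip locations with nothing selected
    pts = pts[mask]            # indexing builds a brand-new tensor...
    print(hasattr(pts, 'labels'))  # False: the attribute did not survive
    pts.labels = labels        # ...so the labels are reattached by hand
```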


@@ -1,6 +1,6 @@
 '''PINA Callbacks Implementations'''
-from lightning.pytorch.callbacks import Callback
+from pytorch_lightning.callbacks import Callback
 import torch
 import copy
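Both import paths expose the same `Callback` base on current installs: `lightning.pytorch` is the layout inside the unified `lightning` distribution, while `pytorch_lightning` is the standalone package, so the swap is about matching the dependency the rest of the codebase uses. Where either distribution may be present, a tolerant import is a common pattern (a sketch, not what this commit does):

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    # fall back to the unified `lightning` package layout
    from lightning.pytorch.callbacks import Callback
```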


@@ -118,8 +118,6 @@ class LabelTensor(torch.Tensor):
         tensors = [lt.extract(labels) for lt in label_tensors]
         return LabelTensor(torch.vstack(tensors), labels)
-    # TODO remove try/ except thing IMPORTANT
-    # make the label None of default
     def clone(self, *args, **kwargs):
         """
         Clone the LabelTensor. For more details, see
@@ -128,11 +126,12 @@ class LabelTensor(torch.Tensor):
         :return: a copy of the tensor
         :rtype: LabelTensor
         """
-        try:
-            out = LabelTensor(super().clone(*args, **kwargs), self.labels)
-        except: # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining
-            out = super().clone(*args, **kwargs)
+        # # used before merging
+        # try:
+        #     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
+        # except:
+        #     out = super().clone(*args, **kwargs)
+        out = LabelTensor(super().clone(*args, **kwargs), self.labels)
         return out
     def to(self, *args, **kwargs):
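The rationale for the `clone` change: `super().clone()` on a `torch.Tensor` subclass copies the data but not instance attributes such as `labels`, so the result is rewrapped, and the silent `try/except` fallback, which could return an unlabeled tensor, is dropped. A toy sketch of the rewrapping pattern (a stand-in, not PINA's actual `LabelTensor`):

```python
import torch

class ToyLabelTensor(torch.Tensor):
    """Toy stand-in for a tensor that carries a list of labels."""

    def __new__(cls, tensor, labels):
        out = tensor.as_subclass(cls)
        out.labels = labels
        return out

    def clone(self, *args, **kwargs):
        # super().clone() copies the data but not `labels`,
        # so the clone is rewrapped to reattach them.
        return ToyLabelTensor(super().clone(*args, **kwargs), self.labels)

t = ToyLabelTensor(torch.rand(3, 2), ['u', 'v'])
c = t.clone()
print(c.labels)                          # ['u', 'v']
print(c.data_ptr() != t.data_ptr())      # True: the data was copied
```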


@@ -166,6 +166,7 @@ class GAROM(SolverInterface):
         Private method to train the generator network.
         """
         optimizer = self.optimizer_generator
+        optimizer.zero_grad()
         generated_snapshots = self.generator(parameters)
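The added `optimizer.zero_grad()` matters because PyTorch accumulates into `.grad` on every `backward()`; without clearing, each generator update would include stale gradients from the previous step. A minimal manual-optimization sketch of a generator step (`generator` and the loss are illustrative stand-ins):

```python
import torch

generator = torch.nn.Linear(4, 8)                  # stand-in generator
optimizer = torch.optim.Adam(generator.parameters())

def generator_step(parameters, target):
    optimizer.zero_grad()        # clear gradients left by the last step;
                                 # .grad buffers otherwise accumulate
    snapshots = generator(parameters)
    g_loss = torch.nn.functional.mse_loss(snapshots, target)
    g_loss.backward()
    optimizer.step()
    return g_loss

loss = generator_step(torch.randn(16, 4), torch.randn(16, 8))
```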
@@ -258,10 +259,10 @@ class GAROM(SolverInterface):
         diff = self._update_weights(d_loss_real, d_loss_fake)
         # logging
-        self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
-        self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
-        self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
-        self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
+        self.log('mean_loss', float(r_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('d_loss', float(d_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('g_loss', float(g_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True, on_epoch=True, on_step=False)
         return


@@ -130,7 +130,7 @@ class PINN(SolverInterface):
             if len(batch) == 2:
                 samples = pts[condition_idx == condition_id]
-                loss = self._loss_phys(pts, condition.equation)
+                loss = self._loss_phys(samples, condition.equation)
             elif len(batch) == 3:
                 samples = pts[condition_idx == condition_id]
                 ground_truth = batch['output'][condition_idx == condition_id]
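The `pts -> samples` fix from the commit message: `samples` holds only the batch points belonging to the current condition, but the old call handed the full `pts` batch to `_loss_phys`, so every condition's physics loss was evaluated on all points. A small sketch of the per-condition filtering (tensor names mirror the diff; the data is made up):

```python
import torch

pts = torch.randn(10, 2)                 # all sampled points in the batch
condition_idx = torch.tensor([0, 0, 1, 1, 1, 0, 2, 2, 1, 0])

for condition_id in condition_idx.unique():
    # keep only the points belonging to the current condition
    samples = pts[condition_idx == condition_id]
    # the physics residual must see `samples`, not the full `pts`
    print(int(condition_id), tuple(samples.shape))
```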
@@ -138,18 +138,19 @@ class PINN(SolverInterface):
             else:
                 raise ValueError("Batch size not supported")
+            # TODO for users this us hard to remebeber when creating a new solver, to fix in a smarter way
+            loss = loss.as_subclass(torch.Tensor)
-            loss = loss
+            # add condition losses and accumulate logging for each epoch
             condition_losses.append(loss * condition.data_weight)
+            self.log(condition_name + '_loss', float(loss),
+                     prog_bar=True, logger=True, on_epoch=True, on_step=False)
-        # TODO Fix the bug, tot_loss is a label tensor without labels
-        # we need to pass it as a torch tensor to make everything work
+        # add to tot loss and accumulate logging for each epoch
         total_loss = sum(condition_losses)
+        self.log('mean_loss', float(total_loss / len(condition_losses)),
+                 prog_bar=True, logger=True, on_epoch=True, on_step=False)
-        self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
-        # for condition_loss, loss in zip(condition_names, condition_losses):
-        #     self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
         return total_loss
     @property