Solvers logging (#202)

* Modifying solvers to log every epoch correctly
* add `on_epoch` flag to logger
* fix bug in `pinn.py` `pts -> samples` in `_loss_phys`
* add `optimizer_zero_grad()` in garom generator training loop
* modify imports in `callbacks.py`
* fixing tests

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-11-08 14:10:23 +01:00
committed by Nicola Demo
parent d654259428
commit 3f9305d475
5 changed files with 28 additions and 27 deletions

View File

@@ -61,7 +61,7 @@ class R3Refinement(Callback):
pts.retain_grad() pts.retain_grad()
# PINN loss: equation evaluated only on locations where sampling is needed # PINN loss: equation evaluated only on locations where sampling is needed
target = condition.equation.residual(pts, solver.forward(pts)) target = condition.equation.residual(pts, solver.forward(pts))
res_loss[location] = torch.abs(target) res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target)) tot_loss.append(torch.abs(target))
return torch.vstack(tot_loss), res_loss return torch.vstack(tot_loss), res_loss
@@ -74,6 +74,7 @@ class R3Refinement(Callback):
""" """
# compute residual (all device possible) # compute residual (all device possible)
tot_loss, res_loss = self._compute_residual(trainer) tot_loss, res_loss = self._compute_residual(trainer)
tot_loss = tot_loss.as_subclass(torch.Tensor)
# !!!!!! From now everything is performed on CPU !!!!!! # !!!!!! From now everything is performed on CPU !!!!!!
@@ -89,10 +90,9 @@ class R3Refinement(Callback):
pts = pts.cpu().detach() pts = pts.cpu().detach()
residuals = res_loss[location].cpu() residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten() mask = (residuals > avg).flatten()
# TODO masking remove labels if any(mask): # if there are residuals greater than averge we append them
pts = pts[mask] pts = pts[mask] # TODO masking remove labels
pts.labels = labels pts.labels = labels
####
old_pts[location] = pts old_pts[location] = pts
tot_points += len(pts) tot_points += len(pts)

View File

@@ -1,6 +1,6 @@
'''PINA Callbacks Implementations''' '''PINA Callbacks Implementations'''
from lightning.pytorch.callbacks import Callback from pytorch_lightning.callbacks import Callback
import torch import torch
import copy import copy

View File

@@ -118,8 +118,6 @@ class LabelTensor(torch.Tensor):
tensors = [lt.extract(labels) for lt in label_tensors] tensors = [lt.extract(labels) for lt in label_tensors]
return LabelTensor(torch.vstack(tensors), labels) return LabelTensor(torch.vstack(tensors), labels)
# TODO remove try/ except thing IMPORTANT
# make the label None of default
def clone(self, *args, **kwargs): def clone(self, *args, **kwargs):
""" """
Clone the LabelTensor. For more details, see Clone the LabelTensor. For more details, see
@@ -128,11 +126,12 @@ class LabelTensor(torch.Tensor):
:return: a copy of the tensor :return: a copy of the tensor
:rtype: LabelTensor :rtype: LabelTensor
""" """
try: # # used before merging
# try:
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
# except:
# out = super().clone(*args, **kwargs)
out = LabelTensor(super().clone(*args, **kwargs), self.labels) out = LabelTensor(super().clone(*args, **kwargs), self.labels)
except: # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining
out = super().clone(*args, **kwargs)
return out return out
def to(self, *args, **kwargs): def to(self, *args, **kwargs):

View File

@@ -166,6 +166,7 @@ class GAROM(SolverInterface):
Private method to train the generator network. Private method to train the generator network.
""" """
optimizer = self.optimizer_generator optimizer = self.optimizer_generator
optimizer.zero_grad()
generated_snapshots = self.generator(parameters) generated_snapshots = self.generator(parameters)
@@ -258,10 +259,10 @@ class GAROM(SolverInterface):
diff = self._update_weights(d_loss_real, d_loss_fake) diff = self._update_weights(d_loss_real, d_loss_fake)
# logging # logging
self.log('mean_loss', float(r_loss), prog_bar=True, logger=True) self.log('mean_loss', float(r_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
self.log('d_loss', float(d_loss), prog_bar=True, logger=True) self.log('d_loss', float(d_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
self.log('g_loss', float(g_loss), prog_bar=True, logger=True) self.log('g_loss', float(g_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True) self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True, on_epoch=True, on_step=False)
return return

View File

@@ -130,7 +130,7 @@ class PINN(SolverInterface):
if len(batch) == 2: if len(batch) == 2:
samples = pts[condition_idx == condition_id] samples = pts[condition_idx == condition_id]
loss = self._loss_phys(pts, condition.equation) loss = self._loss_phys(samples, condition.equation)
elif len(batch) == 3: elif len(batch) == 3:
samples = pts[condition_idx == condition_id] samples = pts[condition_idx == condition_id]
ground_truth = batch['output'][condition_idx == condition_id] ground_truth = batch['output'][condition_idx == condition_id]
@@ -138,18 +138,19 @@ class PINN(SolverInterface):
else: else:
raise ValueError("Batch size not supported") raise ValueError("Batch size not supported")
# TODO for users this us hard to remebeber when creating a new solver, to fix in a smarter way
loss = loss.as_subclass(torch.Tensor) loss = loss.as_subclass(torch.Tensor)
loss = loss
# add condition losses and accumulate logging for each epoch
condition_losses.append(loss * condition.data_weight) condition_losses.append(loss * condition.data_weight)
self.log(condition_name + '_loss', float(loss),
prog_bar=True, logger=True, on_epoch=True, on_step=False)
# TODO Fix the bug, tot_loss is a label tensor without labels # add to tot loss and accumulate logging for each epoch
# we need to pass it as a torch tensor to make everything work
total_loss = sum(condition_losses) total_loss = sum(condition_losses)
self.log('mean_loss', float(total_loss / len(condition_losses)),
prog_bar=True, logger=True, on_epoch=True, on_step=False)
self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
# for condition_loss, loss in zip(condition_names, condition_losses):
# self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
return total_loss return total_loss
@property @property