Solvers logging (#202)
* Modifying solvers to log every epoch correctly
* add `on_epoch` flag to logger
* fix bug in `pinn.py`: `pts -> samples` in `_loss_phys`
* add `optimizer_zero_grad()` in garom generator training loop
* modify imports in `callbacks.py`
* fixing tests

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
commit 3f9305d475 (parent d654259428), committed by Nicola Demo
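As a minimal sketch of what the added `on_epoch` flag does (a toy LightningModule, not PINA's solver code; the import spelling depends on whether `pytorch_lightning` or `lightning.pytorch` is installed): with `on_epoch=True, on_step=False`, Lightning accumulates the value over the epoch and logs the epoch-level mean under the same metric name, instead of reporting every batch.

    # Hedged sketch, not PINA code: epoch-level logging with Lightning's self.log
    import torch
    import pytorch_lightning as pl

    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(1, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # per-step logging (previous behaviour):
            # self.log('mean_loss', float(loss), prog_bar=True, logger=True)
            # epoch-level logging (behaviour introduced by this PR):
            self.log('mean_loss', float(loss),
                     prog_bar=True, logger=True, on_epoch=True, on_step=False)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

After training, the last epoch-aggregated value is available from the trainer, e.g. in `trainer.callback_metrics`.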
@@ -61,7 +61,7 @@ class R3Refinement(Callback):
pts.retain_grad()
# PINN loss: equation evaluated only on locations where sampling is needed
target = condition.equation.residual(pts, solver.forward(pts))
res_loss[location] = torch.abs(target)
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target))

return torch.vstack(tot_loss), res_loss
@@ -74,6 +74,7 @@ class R3Refinement(Callback):
"""
# compute residual (all device possible)
tot_loss, res_loss = self._compute_residual(trainer)
tot_loss = tot_loss.as_subclass(torch.Tensor)

# !!!!!! From now everything is performed on CPU !!!!!!
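The `.as_subclass(torch.Tensor)` calls above convert PINA's LabelTensor (a torch.Tensor subclass carrying column labels) back into a plain tensor before it is accumulated and handed to Lightning's logging machinery. A small sketch of what that call does, using a hypothetical toy subclass as a stand-in for LabelTensor:

    # Sketch of the .as_subclass(torch.Tensor) trick with a toy tensor subclass
    import torch

    class TaggedTensor(torch.Tensor):
        """Hypothetical subclass carrying extra metadata, like LabelTensor does."""

    t = torch.rand(4, 2).as_subclass(TaggedTensor)
    t.tags = ['x', 'y']

    plain = t.as_subclass(torch.Tensor)      # same storage, but a plain torch.Tensor again
    print(type(t).__name__, type(plain).__name__)   # TaggedTensor Tensor
    print(plain.data_ptr() == t.data_ptr())          # True: no copy is made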
@@ -89,10 +90,9 @@ class R3Refinement(Callback):
pts = pts.cpu().detach()
residuals = res_loss[location].cpu()
mask = (residuals > avg).flatten()
# TODO masking remove labels
pts = pts[mask]
if any(mask): # if there are residuals greater than averge we append them
pts = pts[mask] # TODO masking remove labels
pts.labels = labels
####
old_pts[location] = pts
tot_points += len(pts)
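The selection rule in this hunk keeps only the collocation points whose residual exceeds the average, and the `if any(mask)` guard skips the update when no point qualifies. An illustrative sketch (names are stand-ins, not the callback's actual attributes):

    # Sketch of the R3-style point selection: keep points with above-average residual
    import torch

    pts = torch.rand(10, 2)            # stand-in for the sampled collocation points
    residuals = torch.rand(10)         # stand-in for res_loss[location]
    avg = residuals.mean()

    mask = (residuals > avg).flatten()
    if mask.any():                     # same guard as `if any(mask)` in the diff
        kept = pts[mask]               # points re-used in the next training phase
        print(f"kept {len(kept)} of {len(pts)} points")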
@@ -1,6 +1,6 @@
'''PINA Callbacks Implementations'''

from lightning.pytorch.callbacks import Callback
from pytorch_lightning.callbacks import Callback
import torch
import copy
@@ -118,8 +118,6 @@ class LabelTensor(torch.Tensor):
tensors = [lt.extract(labels) for lt in label_tensors]
return LabelTensor(torch.vstack(tensors), labels)

# TODO remove try/ except thing IMPORTANT
# make the label None of default
def clone(self, *args, **kwargs):
"""
Clone the LabelTensor. For more details, see
@@ -128,11 +126,12 @@ class LabelTensor(torch.Tensor):
:return: a copy of the tensor
:rtype: LabelTensor
"""
try:
# # used before merging
# try:
#     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
# except:
#     out = super().clone(*args, **kwargs)
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
except: # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining
out = super().clone(*args, **kwargs)

return out

def to(self, *args, **kwargs):
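The pattern visible in this hunk is a clone() override that re-attaches the labels to the cloned data and falls back to a plain clone when the labels are missing (which can happen once Lightning handles the tensor internally). A hedged sketch of that pattern on a toy subclass, with the bare `except:` narrowed to `AttributeError` here; PINA's real LabelTensor stores its labels differently:

    # Sketch of a label-preserving clone() with a fallback, on a toy subclass
    import torch

    class TaggedTensor(torch.Tensor):
        """Toy stand-in: instances carry a `labels` attribute set after creation."""

        def clone(self, *args, **kwargs):
            try:
                out = super().clone(*args, **kwargs).as_subclass(TaggedTensor)
                out.labels = self.labels      # raises AttributeError if labels were lost
            except AttributeError:
                out = super().clone(*args, **kwargs)   # fall back to a plain clone
            return out

    t = torch.rand(3, 2).as_subclass(TaggedTensor)
    t.labels = ['u', 'v']
    c = t.clone()
    print(c.labels)   # ['u', 'v']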
@@ -166,6 +166,7 @@ class GAROM(SolverInterface):
Private method to train the generator network.
"""
optimizer = self.optimizer_generator
optimizer.zero_grad()

generated_snapshots = self.generator(parameters)
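The added `optimizer.zero_grad()` clears the gradients left over from the previous step before the generator's manual update; without it, gradients accumulate across steps. A hedged sketch of a manual generator update of this shape (the names and the loss are illustrative, not GAROM's actual attributes or objective):

    # Illustrative manual generator step; names and loss are assumptions, not GAROM code
    import torch

    def train_generator(generator, discriminator, optimizer, parameters, snapshots):
        optimizer.zero_grad()               # the line added by this PR: clear stale grads
        generated = generator(parameters)   # forward pass of the generator
        d_fake = discriminator(generated)   # score the generated snapshots
        g_loss = torch.nn.functional.mse_loss(generated, snapshots) - d_fake.mean()
        g_loss.backward()                   # without zero_grad() these grads would pile up
        optimizer.step()
        return g_loss.detach()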
@@ -258,10 +259,10 @@ class GAROM(SolverInterface):
diff = self._update_weights(d_loss_real, d_loss_fake)

# logging
self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
self.log('mean_loss', float(r_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
self.log('d_loss', float(d_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
self.log('g_loss', float(g_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True, on_epoch=True, on_step=False)

return
@@ -130,7 +130,7 @@ class PINN(SolverInterface):

if len(batch) == 2:
samples = pts[condition_idx == condition_id]
loss = self._loss_phys(pts, condition.equation)
loss = self._loss_phys(samples, condition.equation)
elif len(batch) == 3:
samples = pts[condition_idx == condition_id]
ground_truth = batch['output'][condition_idx == condition_id]
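The bug fixed here is that the physics loss was evaluated on the whole batch `pts` rather than on the subset `samples` belonging to the current condition. A small sketch of the per-condition filtering the fix restores (toy data, illustrative names):

    # Sketch of per-condition filtering: each condition sees only its own points
    import torch

    pts = torch.rand(8, 2)                          # all points in the batch
    condition_idx = torch.tensor([0, 0, 1, 1, 1, 0, 1, 0])

    for condition_id in condition_idx.unique():
        samples = pts[condition_idx == condition_id]   # subset for this condition
        # before the fix the solver effectively called _loss_phys(pts, ...),
        # i.e. every condition's equation was evaluated on every point
        print(int(condition_id), samples.shape)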
@@ -138,18 +138,19 @@ class PINN(SolverInterface):
else:
raise ValueError("Batch size not supported")

# TODO for users this us hard to remebeber when creating a new solver, to fix in a smarter way
loss = loss.as_subclass(torch.Tensor)
loss = loss

# add condition losses and accumulate logging for each epoch
condition_losses.append(loss * condition.data_weight)
self.log(condition_name + '_loss', float(loss),
prog_bar=True, logger=True, on_epoch=True, on_step=False)

# TODO Fix the bug, tot_loss is a label tensor without labels
# we need to pass it as a torch tensor to make everything work
# add to tot loss and accumulate logging for each epoch
total_loss = sum(condition_losses)
self.log('mean_loss', float(total_loss / len(condition_losses)),
prog_bar=True, logger=True, on_epoch=True, on_step=False)

self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
# for condition_loss, loss in zip(condition_names, condition_losses):
# self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
return total_loss

@property