Solvers logging (#202)
* Modify solvers to log correctly once per epoch
* Add `on_epoch` flag to the logger calls
* Fix bug in `pinn.py`: pass `samples` instead of `pts` to `_loss_phys`
* Add `optimizer.zero_grad()` in the GAROM generator training loop
* Update imports in `callbacks.py`
* Fix tests

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
Committed by Nicola Demo (parent d654259428, commit 3f9305d475)
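The core change is the switch to per-epoch logging. As a hedged illustration (not PINA code), the pattern the solvers adopt below is Lightning's `self.log(..., on_epoch=True, on_step=False)`, which accumulates the metric over the epoch and records one reduced value per epoch instead of one per training step; the `TinySolver` module here is a made-up stand-in:

import torch
from pytorch_lightning import LightningModule

class TinySolver(LightningModule):
    """Minimal stand-in module showing the per-epoch logging pattern."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # on_epoch=True, on_step=False: Lightning averages the values logged
        # during the epoch and writes a single 'mean_loss' per epoch
        self.log('mean_loss', float(loss),
                 prog_bar=True, logger=True, on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)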
@@ -61,7 +61,7 @@ class R3Refinement(Callback):
             pts.retain_grad()
             # PINN loss: equation evaluated only on locations where sampling is needed
             target = condition.equation.residual(pts, solver.forward(pts))
-            res_loss[location] = torch.abs(target)
+            res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
             tot_loss.append(torch.abs(target))

         return torch.vstack(tot_loss), res_loss
@@ -74,6 +74,7 @@ class R3Refinement(Callback):
         """
         # compute residual (all device possible)
         tot_loss, res_loss = self._compute_residual(trainer)
+        tot_loss = tot_loss.as_subclass(torch.Tensor)

         # !!!!!! From now everything is performed on CPU !!!!!!

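Both hunks above cast results with `.as_subclass(torch.Tensor)`. A minimal sketch of what that call does, using a made-up `TaggedTensor` subclass in place of PINA's `LabelTensor`: it returns the same data viewed as a plain `torch.Tensor`, so downstream operations no longer go through the subclass and its label bookkeeping.

import torch

class TaggedTensor(torch.Tensor):
    """Made-up stand-in for a labelled tensor subclass such as LabelTensor."""

t = torch.rand(4, 2).as_subclass(TaggedTensor)   # view the data as the subclass
plain = t.as_subclass(torch.Tensor)              # view it back as a plain Tensor
print(type(t).__name__, type(plain).__name__)    # -> TaggedTensor Tensor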
@@ -89,10 +90,9 @@ class R3Refinement(Callback):
             pts = pts.cpu().detach()
             residuals = res_loss[location].cpu()
             mask = (residuals > avg).flatten()
-            # TODO masking remove labels
-            pts = pts[mask]
-            pts.labels = labels
-            ####
-            old_pts[location] = pts
-            tot_points += len(pts)
+            if any(mask):  # if there are residuals greater than averge we append them
+                pts = pts[mask]  # TODO masking remove labels
+                pts.labels = labels
+                old_pts[location] = pts
+                tot_points += len(pts)

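For context, a small self-contained sketch (plain tensors, made-up data) of the selection step this hunk guards with `if any(mask)`: only points whose residual exceeds the average are kept, and the boolean indexing is also the step that currently drops the labels, per the TODO.

import torch

pts = torch.rand(10, 2)              # candidate collocation points
residuals = torch.rand(10, 1)        # |residual| evaluated at each point
avg = residuals.mean()
mask = (residuals > avg).flatten()
if mask.any():                       # append only if something exceeds the average
    pts = pts[mask]                  # keep the high-residual points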
@@ -1,6 +1,6 @@
 '''PINA Callbacks Implementations'''

-from lightning.pytorch.callbacks import Callback
+from pytorch_lightning.callbacks import Callback
 import torch
 import copy

@@ -118,8 +118,6 @@ class LabelTensor(torch.Tensor):
         tensors = [lt.extract(labels) for lt in label_tensors]
         return LabelTensor(torch.vstack(tensors), labels)

-    # TODO remove try/ except thing IMPORTANT
-    # make the label None of default
     def clone(self, *args, **kwargs):
         """
         Clone the LabelTensor. For more details, see
@@ -128,11 +126,12 @@ class LabelTensor(torch.Tensor):
         :return: a copy of the tensor
         :rtype: LabelTensor
         """
-        try:
-            out = LabelTensor(super().clone(*args, **kwargs), self.labels)
-        except:  # this is used when the tensor loose the labels, notice it will create a bug! Kept for compatibility with Lightining
-            out = super().clone(*args, **kwargs)
+        # # used before merging
+        # try:
+        #     out = LabelTensor(super().clone(*args, **kwargs), self.labels)
+        # except:
+        #     out = super().clone(*args, **kwargs)
+        out = LabelTensor(super().clone(*args, **kwargs), self.labels)

         return out

     def to(self, *args, **kwargs):
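The simplified `clone` above always re-wraps the copy with the original labels instead of silently falling back to a plain tensor. A hedged usage sketch, with the constructor and `labels` attribute used as they appear in this diff (the import path and label comparison are assumptions):

import torch
from pina import LabelTensor   # import path assumed

lt = LabelTensor(torch.rand(5, 2), ['x', 'y'])
copy_ = lt.clone()                     # copy keeps both the data and the labels
assert isinstance(copy_, LabelTensor)
assert copy_.labels == lt.labels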
@@ -166,6 +166,7 @@ class GAROM(SolverInterface):
         Private method to train the generator network.
         """
         optimizer = self.optimizer_generator
+        optimizer.zero_grad()

         generated_snapshots = self.generator(parameters)

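The single added line follows the usual manual-optimization pattern: clear gradients before computing the generator loss, otherwise `.backward()` accumulates gradients across iterations. A generic runnable sketch (not the GAROM code; model and data are made up):

import torch

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x, y = torch.rand(8, 3), torch.rand(8, 1)

for _ in range(2):
    optimizer.zero_grad()        # without this, gradients from the previous
                                 # iteration would be added to the new ones
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()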
@@ -258,10 +259,10 @@ class GAROM(SolverInterface):
         diff = self._update_weights(d_loss_real, d_loss_fake)

         # logging
-        self.log('mean_loss', float(r_loss), prog_bar=True, logger=True)
-        self.log('d_loss', float(d_loss), prog_bar=True, logger=True)
-        self.log('g_loss', float(g_loss), prog_bar=True, logger=True)
-        self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True)
+        self.log('mean_loss', float(r_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('d_loss', float(d_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('g_loss', float(g_loss), prog_bar=True, logger=True, on_epoch=True, on_step=False)
+        self.log('stability_metric', float(d_loss_real + torch.abs(diff)), prog_bar=True, logger=True, on_epoch=True, on_step=False)

         return

@@ -130,7 +130,7 @@ class PINN(SolverInterface):

             if len(batch) == 2:
                 samples = pts[condition_idx == condition_id]
-                loss = self._loss_phys(pts, condition.equation)
+                loss = self._loss_phys(samples, condition.equation)
             elif len(batch) == 3:
                 samples = pts[condition_idx == condition_id]
                 ground_truth = batch['output'][condition_idx == condition_id]
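The fixed line makes `_loss_phys` receive only the points of the current condition. A short sketch (made-up shapes and ids) of the per-condition selection both branches rely on:

import torch

pts = torch.rand(6, 2)                            # all points in the batch
condition_idx = torch.tensor([0, 0, 1, 1, 1, 2])  # condition id of each point
condition_id = 1
samples = pts[condition_idx == condition_id]      # only these feed _loss_phys
print(samples.shape)                              # torch.Size([3, 2])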
@@ -138,18 +138,19 @@ class PINN(SolverInterface):
             else:
                 raise ValueError("Batch size not supported")

+            # TODO for users this us hard to remebeber when creating a new solver, to fix in a smarter way
             loss = loss.as_subclass(torch.Tensor)
-            loss = loss

+            # add condition losses and accumulate logging for each epoch
             condition_losses.append(loss * condition.data_weight)
+            self.log(condition_name + '_loss', float(loss),
+                     prog_bar=True, logger=True, on_epoch=True, on_step=False)

-        # TODO Fix the bug, tot_loss is a label tensor without labels
-        # we need to pass it as a torch tensor to make everything work
+        # add to tot loss and accumulate logging for each epoch
         total_loss = sum(condition_losses)
+        self.log('mean_loss', float(total_loss / len(condition_losses)),
+                 prog_bar=True, logger=True, on_epoch=True, on_step=False)

-        self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
-        # for condition_loss, loss in zip(condition_names, condition_losses):
-        #     self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
         return total_loss

     @property
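A tiny numeric illustration (made-up weights and losses) of the aggregation the training step above performs: each condition loss is scaled by its `data_weight` and summed into the returned training loss, while the logged `mean_loss` is the average of the weighted condition losses.

condition_losses = [0.30 * 1.0, 0.10 * 2.0]       # loss * condition.data_weight
total_loss = sum(condition_losses)                # 0.50, returned for backprop
mean_loss = total_loss / len(condition_losses)    # 0.25, logged once per epoch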