Solvers logging (#202)
* Modifying solvers to log every epoch correctly * add `on_epoch` flag to logger * fix bug in `pinn.py` `pts -> samples` in `_loss_phys` * add `optimizer_zero_grad()` in garom generator training loop * modify imports in `callbacks.py` * fixing tests --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
d654259428
commit
3f9305d475
@@ -130,7 +130,7 @@ class PINN(SolverInterface):
|
||||
|
||||
if len(batch) == 2:
|
||||
samples = pts[condition_idx == condition_id]
|
||||
loss = self._loss_phys(pts, condition.equation)
|
||||
loss = self._loss_phys(samples, condition.equation)
|
||||
elif len(batch) == 3:
|
||||
samples = pts[condition_idx == condition_id]
|
||||
ground_truth = batch['output'][condition_idx == condition_id]
|
||||
@@ -138,18 +138,19 @@ class PINN(SolverInterface):
|
||||
else:
|
||||
raise ValueError("Batch size not supported")
|
||||
|
||||
# TODO for users this us hard to remebeber when creating a new solver, to fix in a smarter way
|
||||
loss = loss.as_subclass(torch.Tensor)
|
||||
loss = loss
|
||||
|
||||
# add condition losses and accumulate logging for each epoch
|
||||
condition_losses.append(loss * condition.data_weight)
|
||||
self.log(condition_name + '_loss', float(loss),
|
||||
prog_bar=True, logger=True, on_epoch=True, on_step=False)
|
||||
|
||||
# TODO Fix the bug, tot_loss is a label tensor without labels
|
||||
# we need to pass it as a torch tensor to make everything work
|
||||
# add to tot loss and accumulate logging for each epoch
|
||||
total_loss = sum(condition_losses)
|
||||
self.log('mean_loss', float(total_loss / len(condition_losses)),
|
||||
prog_bar=True, logger=True, on_epoch=True, on_step=False)
|
||||
|
||||
self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
|
||||
# for condition_loss, loss in zip(condition_names, condition_losses):
|
||||
# self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
|
||||
return total_loss
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user