Tutorials v0.1 (#178)

Tutorial update and small fixes

* Tutorials update + Tutorial FNO
* Create a metric tracker callback
* Update PINN for logging
* Update plotter for plotting
* Small fix LabelTensor
* Small fix FNO

---------

Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-13-250.WIFIeduroamSTUD.units.it>
Co-authored-by: Dario Coscia <dariocoscia@dhcp-176.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-09-26 17:29:37 +02:00
committed by Nicola Demo
parent 939353f517
commit a9b1bd2826
45 changed files with 2760 additions and 1321 deletions

View File

@@ -109,12 +109,14 @@ class PINN(SolverInterface):
"""
condition_losses = []
condition_names = []
for condition_name, samples in batch.items():
if condition_name not in self.problem.conditions:
raise RuntimeError('Something wrong happened.')
condition_names.append(condition_name)
condition = self.problem.conditions[condition_name]
# PINN loss: equation evaluated on location or input_points
@@ -132,9 +134,9 @@ class PINN(SolverInterface):
# we need to pass it as a torch tensor to make everything work
total_loss = sum(condition_losses)
self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=False)
for condition_loss, loss in zip(self.problem.conditions, condition_losses):
self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=False)
self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
for condition_loss, loss in zip(condition_names, condition_losses):
self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
return total_loss
@property