Tutorials v0.1 (#178)
Tutorial update and small fixes * Tutorials update + Tutorial FNO * Create a metric tracker callback * Update PINN for logging * Update plotter for plotting * Small fix LabelTensor * Small fix FNO --------- Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-13-250.WIFIeduroamSTUD.units.it> Co-authored-by: Dario Coscia <dariocoscia@dhcp-176.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
939353f517
commit
a9b1bd2826
@@ -109,12 +109,14 @@ class PINN(SolverInterface):
|
||||
"""
|
||||
|
||||
condition_losses = []
|
||||
condition_names = []
|
||||
|
||||
for condition_name, samples in batch.items():
|
||||
|
||||
if condition_name not in self.problem.conditions:
|
||||
raise RuntimeError('Something wrong happened.')
|
||||
|
||||
condition_names.append(condition_name)
|
||||
condition = self.problem.conditions[condition_name]
|
||||
|
||||
# PINN loss: equation evaluated on location or input_points
|
||||
@@ -132,9 +134,9 @@ class PINN(SolverInterface):
|
||||
# we need to pass it as a torch tensor to make everything work
|
||||
total_loss = sum(condition_losses)
|
||||
|
||||
self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=False)
|
||||
for condition_loss, loss in zip(self.problem.conditions, condition_losses):
|
||||
self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=False)
|
||||
self.log('mean_loss', float(total_loss / len(condition_losses)), prog_bar=True, logger=True)
|
||||
for condition_loss, loss in zip(condition_names, condition_losses):
|
||||
self.log(condition_loss + '_loss', float(loss), prog_bar=True, logger=True)
|
||||
return total_loss
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user