tutorial validation (#185)
Co-authored-by: Ben Volokh <89551265+benv123@users.noreply.github.com>
This commit is contained in:
17
tutorials/tutorial2/tutorial.py
vendored
17
tutorials/tutorial2/tutorial.py
vendored
@@ -98,7 +98,7 @@ trainer = Trainer(pinn, max_epochs=1000, callbacks=[MetricTracker()])
|
||||
trainer.train()
|
||||
|
||||
|
||||
# Now the *Plotter* class is used to plot the results.
|
||||
# Now the `Plotter` class is used to plot the results.
|
||||
# The solution predicted by the neural network is plotted on the left, the exact one is represented at the center and on the right the error between the exact and the predicted solutions is showed.
|
||||
|
||||
# In[4]:
|
||||
@@ -238,18 +238,3 @@ trainer_learn.train()
|
||||
|
||||
plotter.plot(trainer_learn)
|
||||
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.figure(figsize=(16, 6))
|
||||
plotter.plot_loss(trainer, label='Standard')
|
||||
plotter.plot_loss(trainer_feat, label='Static Features')
|
||||
plotter.plot_loss(trainer_learn, label='Learnable Features')
|
||||
|
||||
plt.grid()
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user