documentation (#79)

Documentation for operator.py, span.py, plotter.py. 
Co-authored-by: Dario Coscia <dariocoscia@dhcp-128.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-04-18 10:48:11 +02:00
committed by GitHub
parent f4efaff5a5
commit c536f8267f
8 changed files with 317 additions and 24 deletions

View File

@@ -1,6 +1,6 @@
""" Module for plotting. """
import matplotlib
#matplotlib.use('Qt5Agg')
# matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
@@ -12,8 +12,24 @@ from .problem import SpatialProblem, TimeDependentProblem
class Plotter:
"""
Implementation of a plotter class, for easy visualizations.
"""
def plot_samples(self, pinn, variables=None):
"""
Plot a sample of solution.
:param pinn: the PINN object.
:type pinn: PINN
:param variables: pinn variable domains: spatial or temporal,
defaults to None.
:type variables: str, optional
:Example:
>>> plotter = Plotter()
>>> plotter.plot_samples(pinn=pinn, variables='spatial')
"""
if variables is None:
variables = pinn.problem.domain.variables
@@ -51,7 +67,17 @@ class Plotter:
plt.show()
def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs):
"""
"""Plot solution for one dimensional function
:param pts: Points to plot the solution.
:type pts: torch.Tensor
:param pred: PINN solution evaluated at 'pts'.
:type pred: torch.Tensor
:param method: not used, kept for code compatibility
:type method: None
:param truth_solution: Real solution evaluated at 'pts',
defaults to None.
:type truth_solution: torch.Tensor, optional
"""
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
@@ -67,7 +93,19 @@ class Plotter:
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None,
**kwargs):
"""
"""Plot solution for two dimensional function
:param pts: Points to plot the solution.
:type pts: torch.Tensor
:param pred: PINN solution evaluated at 'pts'.
:type pred: torch.Tensor
:param method: matplotlib method to plot 2-dimensional data,
see https://matplotlib.org/stable/api/axes_api.html for
reference.
:type method: str
:param truth_solution: Real solution evaluated at 'pts',
defaults to None.
:type truth_solution: torch.Tensor, optional
"""
grids = [p_.reshape(res, res) for p_ in pts.extract(v).cpu().T]
@@ -77,9 +115,11 @@ class Plotter:
truth_output = truth_solution(pts).float().reshape(res, res)
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(), **kwargs)
cb = getattr(ax[0], method)(
*grids, pred_output.cpu().detach(), **kwargs)
fig.colorbar(cb, ax=ax[0])
cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(), **kwargs)
cb = getattr(ax[1], method)(
*grids, truth_output.cpu().detach(), **kwargs)
fig.colorbar(cb, ax=ax[1])
cb = getattr(ax[2], method)(*grids,
(truth_output-pred_output).cpu().detach(),
@@ -87,13 +127,31 @@ class Plotter:
fig.colorbar(cb, ax=ax[2])
else:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
cb = getattr(ax, method)(*grids, pred_output.cpu().detach(), **kwargs)
cb = getattr(ax, method)(
*grids, pred_output.cpu().detach(), **kwargs)
fig.colorbar(cb, ax=ax)
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
res=256, filename=None, **kwargs):
"""
Plot sample of PINN output.
:param pinn: the PINN object.
:type pinn: PINN
:param components: function components to plot, defaults to None.
:type components: list['str'], optional
:param fixed_variables: function variables to be kept fixed during
plotting passed as a dict where the dict-key is the variable
and the dict-value is the value to be kept fixed, defaults to {}.
:type fixed_variables: dict, optional
:param method: matplotlib method to plot the solution,
defaults to 'contourf'.
:type method: str, optional
:param res: number of points used for plotting in each axis,
defaults to 256.
:type res: int, optional
:param filename: file name to save the plot, defaults to None
:type filename: str, optional
"""
if components is None:
components = [pinn.problem.output_variables]
@@ -134,9 +192,14 @@ class Plotter:
def plot_loss(self, pinn, label=None, log_scale=True):
"""
Plot the loss trend
Plot the loss function values during traininig.
TODO
:param pinn: the PINN object.
:type pinn: PINN
:param label: matplolib label, defaults to None
:type label: str, optional
:param log_scale: use of log scale in plotting, defaults to True.
:type log_scale: bool, optional
"""
if not label: