documentation (#79)
Documentation for operator.py, span.py, plotter.py. Co-authored-by: Dario Coscia <dariocoscia@dhcp-128.eduroam.sissa.it>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
""" Module for plotting. """
|
||||
import matplotlib
|
||||
#matplotlib.use('Qt5Agg')
|
||||
# matplotlib.use('Qt5Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -12,8 +12,24 @@ from .problem import SpatialProblem, TimeDependentProblem
|
||||
|
||||
|
||||
class Plotter:
|
||||
"""
|
||||
Implementation of a plotter class, for easy visualizations.
|
||||
"""
|
||||
|
||||
def plot_samples(self, pinn, variables=None):
|
||||
"""
|
||||
Plot a sample of solution.
|
||||
|
||||
:param pinn: the PINN object.
|
||||
:type pinn: PINN
|
||||
:param variables: pinn variable domains: spatial or temporal,
|
||||
defaults to None.
|
||||
:type variables: str, optional
|
||||
|
||||
:Example:
|
||||
>>> plotter = Plotter()
|
||||
>>> plotter.plot_samples(pinn=pinn, variables='spatial')
|
||||
"""
|
||||
|
||||
if variables is None:
|
||||
variables = pinn.problem.domain.variables
|
||||
@@ -51,7 +67,17 @@ class Plotter:
|
||||
plt.show()
|
||||
|
||||
def _1d_plot(self, pts, pred, method, truth_solution=None, **kwargs):
|
||||
"""
|
||||
"""Plot solution for one dimensional function
|
||||
|
||||
:param pts: Points to plot the solution.
|
||||
:type pts: torch.Tensor
|
||||
:param pred: PINN solution evaluated at 'pts'.
|
||||
:type pred: torch.Tensor
|
||||
:param method: not used, kept for code compatibility
|
||||
:type method: None
|
||||
:param truth_solution: Real solution evaluated at 'pts',
|
||||
defaults to None.
|
||||
:type truth_solution: torch.Tensor, optional
|
||||
"""
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
||||
|
||||
@@ -67,7 +93,19 @@ class Plotter:
|
||||
|
||||
def _2d_plot(self, pts, pred, v, res, method, truth_solution=None,
|
||||
**kwargs):
|
||||
"""
|
||||
"""Plot solution for two dimensional function
|
||||
|
||||
:param pts: Points to plot the solution.
|
||||
:type pts: torch.Tensor
|
||||
:param pred: PINN solution evaluated at 'pts'.
|
||||
:type pred: torch.Tensor
|
||||
:param method: matplotlib method to plot 2-dimensional data,
|
||||
see https://matplotlib.org/stable/api/axes_api.html for
|
||||
reference.
|
||||
:type method: str
|
||||
:param truth_solution: Real solution evaluated at 'pts',
|
||||
defaults to None.
|
||||
:type truth_solution: torch.Tensor, optional
|
||||
"""
|
||||
|
||||
grids = [p_.reshape(res, res) for p_ in pts.extract(v).cpu().T]
|
||||
@@ -77,9 +115,11 @@ class Plotter:
|
||||
truth_output = truth_solution(pts).float().reshape(res, res)
|
||||
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
||||
|
||||
cb = getattr(ax[0], method)(*grids, pred_output.cpu().detach(), **kwargs)
|
||||
cb = getattr(ax[0], method)(
|
||||
*grids, pred_output.cpu().detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax[0])
|
||||
cb = getattr(ax[1], method)(*grids, truth_output.cpu().detach(), **kwargs)
|
||||
cb = getattr(ax[1], method)(
|
||||
*grids, truth_output.cpu().detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax[1])
|
||||
cb = getattr(ax[2], method)(*grids,
|
||||
(truth_output-pred_output).cpu().detach(),
|
||||
@@ -87,13 +127,31 @@ class Plotter:
|
||||
fig.colorbar(cb, ax=ax[2])
|
||||
else:
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
||||
cb = getattr(ax, method)(*grids, pred_output.cpu().detach(), **kwargs)
|
||||
cb = getattr(ax, method)(
|
||||
*grids, pred_output.cpu().detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax)
|
||||
|
||||
|
||||
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
|
||||
res=256, filename=None, **kwargs):
|
||||
"""
|
||||
Plot sample of PINN output.
|
||||
|
||||
:param pinn: the PINN object.
|
||||
:type pinn: PINN
|
||||
:param components: function components to plot, defaults to None.
|
||||
:type components: list['str'], optional
|
||||
:param fixed_variables: function variables to be kept fixed during
|
||||
plotting passed as a dict where the dict-key is the variable
|
||||
and the dict-value is the value to be kept fixed, defaults to {}.
|
||||
:type fixed_variables: dict, optional
|
||||
:param method: matplotlib method to plot the solution,
|
||||
defaults to 'contourf'.
|
||||
:type method: str, optional
|
||||
:param res: number of points used for plotting in each axis,
|
||||
defaults to 256.
|
||||
:type res: int, optional
|
||||
:param filename: file name to save the plot, defaults to None
|
||||
:type filename: str, optional
|
||||
"""
|
||||
if components is None:
|
||||
components = [pinn.problem.output_variables]
|
||||
@@ -134,9 +192,14 @@ class Plotter:
|
||||
|
||||
def plot_loss(self, pinn, label=None, log_scale=True):
|
||||
"""
|
||||
Plot the loss trend
|
||||
Plot the loss function values during traininig.
|
||||
|
||||
TODO
|
||||
:param pinn: the PINN object.
|
||||
:type pinn: PINN
|
||||
:param label: matplolib label, defaults to None
|
||||
:type label: str, optional
|
||||
:param log_scale: use of log scale in plotting, defaults to True.
|
||||
:type log_scale: bool, optional
|
||||
"""
|
||||
|
||||
if not label:
|
||||
|
||||
Reference in New Issue
Block a user