@@ -1,14 +1,9 @@
|
||||
""" Module for plotting. """
|
||||
import matplotlib
|
||||
# matplotlib.use('Qt5Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pina import LabelTensor
|
||||
from pina import PINN
|
||||
from .problem import SpatialProblem, TimeDependentProblem
|
||||
#from pina.tdproblem1d import TimeDepProblem1D
|
||||
|
||||
|
||||
class Plotter:
|
||||
@@ -20,11 +15,14 @@ class Plotter:
|
||||
"""
|
||||
Plot a sample of solution.
|
||||
|
||||
:param pinn: the PINN object.
|
||||
:type pinn: PINN
|
||||
:param variables: pinn variable domains: spatial or temporal,
|
||||
defaults to None.
|
||||
:type variables: str, optional
|
||||
:param PINN pinn: the PINN object.
|
||||
:param list(str) variables: variables to plot. If None, all variables
|
||||
are plotted. If 'spatial', only spatial variables are plotted. If
|
||||
'temporal', only temporal variables are plotted. Defaults to None.
|
||||
|
||||
.. todo::
|
||||
- Add support for 3D plots.
|
||||
- Fix support for more complex problems.
|
||||
|
||||
:Example:
|
||||
>>> plotter = Plotter()
|
||||
@@ -134,24 +132,22 @@ class Plotter:
|
||||
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
|
||||
res=256, filename=None, **kwargs):
|
||||
"""
|
||||
Plot sample of PINN output.
|
||||
Plot sample of PINN output.
|
||||
|
||||
:param pinn: the PINN object.
|
||||
:type pinn: PINN
|
||||
:param components: function components to plot, defaults to None.
|
||||
:type components: list['str'], optional
|
||||
:param fixed_variables: function variables to be kept fixed during
|
||||
plotting passed as a dict where the dict-key is the variable
|
||||
and the dict-value is the value to be kept fixed, defaults to {}.
|
||||
:type fixed_variables: dict, optional
|
||||
:param method: matplotlib method to plot the solution,
|
||||
defaults to 'contourf'.
|
||||
:type method: str, optional
|
||||
:param res: number of points used for plotting in each axis,
|
||||
defaults to 256.
|
||||
:type res: int, optional
|
||||
:param filename: file name to save the plot, defaults to None
|
||||
:type filename: str, optional
|
||||
:param PINN pinn: the PINN object.
|
||||
:param list(str) components: the output variable to plot. If None, all
|
||||
the output variables of the problem are selected. Default value is
|
||||
None.
|
||||
:param dict fixed_variables: a dictionary with all the variables that
|
||||
should be kept fixed during the plot. The keys of the dictionary
|
||||
are the variables name whereas the values are the corresponding
|
||||
values of the variables. Defaults is `dict()`.
|
||||
:param {'contourf', 'pcolor'} method: the matplotlib method to use for
|
||||
plotting the solution. Default is 'contourf'.
|
||||
:param int res: the resolution, aka the number of points used for
|
||||
plotting in each axis. Default is 256.
|
||||
:param str filename: the file name to save the plot. If None, the plot
|
||||
is shown using the setted matplotlib frontend. Default is None.
|
||||
"""
|
||||
if components is None:
|
||||
components = [pinn.problem.output_variables]
|
||||
@@ -192,14 +188,12 @@ class Plotter:
|
||||
|
||||
def plot_loss(self, pinn, label=None, log_scale=True):
|
||||
"""
|
||||
Plot the loss function values during traininig.
|
||||
Plot the loss function values during traininig.
|
||||
|
||||
:param pinn: the PINN object.
|
||||
:type pinn: PINN
|
||||
:param label: matplolib label, defaults to None
|
||||
:type label: str, optional
|
||||
:param log_scale: use of log scale in plotting, defaults to True.
|
||||
:type log_scale: bool, optional
|
||||
:param PINN pinn: the PINN object.
|
||||
:param str label: the label to use in the legend, defaults to None.
|
||||
:param bool log_scale: If True, the y axis is in log scale. Default is
|
||||
True.
|
||||
"""
|
||||
|
||||
if not label:
|
||||
|
||||
Reference in New Issue
Block a user