* clean `condition` module
* add docs
This commit is contained in:
Nicola Demo
2023-04-18 15:00:26 +02:00
committed by GitHub
parent 736c78fd64
commit 2ca08b5236
18 changed files with 198 additions and 158 deletions

View File

@@ -1,14 +1,9 @@
""" Module for plotting. """
import matplotlib
# matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
from pina import LabelTensor
from pina import PINN
from .problem import SpatialProblem, TimeDependentProblem
#from pina.tdproblem1d import TimeDepProblem1D
class Plotter:
@@ -20,11 +15,14 @@ class Plotter:
"""
Plot a sample of solution.
:param pinn: the PINN object.
:type pinn: PINN
:param variables: pinn variable domains: spatial or temporal,
defaults to None.
:type variables: str, optional
:param PINN pinn: the PINN object.
:param list(str) variables: variables to plot. If None, all variables
are plotted. If 'spatial', only spatial variables are plotted. If
'temporal', only temporal variables are plotted. Defaults to None.
.. todo::
- Add support for 3D plots.
- Fix support for more complex problems.
:Example:
>>> plotter = Plotter()
@@ -134,24 +132,22 @@ class Plotter:
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
res=256, filename=None, **kwargs):
"""
Plot sample of PINN output.
Plot sample of PINN output.
:param pinn: the PINN object.
:type pinn: PINN
:param components: function components to plot, defaults to None.
:type components: list['str'], optional
:param fixed_variables: function variables to be kept fixed during
plotting passed as a dict where the dict-key is the variable
and the dict-value is the value to be kept fixed, defaults to {}.
:type fixed_variables: dict, optional
:param method: matplotlib method to plot the solution,
defaults to 'contourf'.
:type method: str, optional
:param res: number of points used for plotting in each axis,
defaults to 256.
:type res: int, optional
:param filename: file name to save the plot, defaults to None
:type filename: str, optional
:param PINN pinn: the PINN object.
:param list(str) components: the output variable to plot. If None, all
the output variables of the problem are selected. Default value is
None.
:param dict fixed_variables: a dictionary with all the variables that
should be kept fixed during the plot. The keys of the dictionary
are the variables name whereas the values are the corresponding
values of the variables. Defaults is `dict()`.
:param {'contourf', 'pcolor'} method: the matplotlib method to use for
plotting the solution. Default is 'contourf'.
:param int res: the resolution, aka the number of points used for
plotting in each axis. Default is 256.
:param str filename: the file name to save the plot. If None, the plot
is shown using the setted matplotlib frontend. Default is None.
"""
if components is None:
components = [pinn.problem.output_variables]
@@ -192,14 +188,12 @@ class Plotter:
def plot_loss(self, pinn, label=None, log_scale=True):
"""
Plot the loss function values during traininig.
Plot the loss function values during traininig.
:param pinn: the PINN object.
:type pinn: PINN
:param label: matplolib label, defaults to None
:type label: str, optional
:param log_scale: use of log scale in plotting, defaults to True.
:type log_scale: bool, optional
:param PINN pinn: the PINN object.
:param str label: the label to use in the legend, defaults to None.
:param bool log_scale: If True, the y axis is in log scale. Default is
True.
"""
if not label: