From 810d215ca0dd08d7f129dbfb4c9ac9ee129d7ae7 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Wed, 19 Feb 2025 11:14:53 +0100 Subject: [PATCH] Modify pina.__init__.py and rm useless .py files * rm meta.py, plotter.py, writer.py * modify __init__ file * modify tests due to __init__ import --- pina/__init__.py | 21 +- pina/meta.py | 22 -- pina/plotter.py | 323 ------------------ pina/writer.py | 50 --- tests/test_optimizer.py | 2 +- tests/test_plotter.py | 75 ---- .../test_supervised_problem.py | 2 +- tests/test_scheduler.py | 2 +- 8 files changed, 12 insertions(+), 485 deletions(-) delete mode 100644 pina/meta.py delete mode 100644 pina/plotter.py delete mode 100644 pina/writer.py delete mode 100644 tests/test_plotter.py diff --git a/pina/__init__.py b/pina/__init__.py index e69db88..06af48f 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -1,18 +1,15 @@ __all__ = [ - "Trainer", "LabelTensor", "Plotter", "Condition", - "PinaDataModule", 'TorchOptimizer', 'Graph', - "RadiusGraph", "KNNGraph" + "Trainer", + "LabelTensor", + "Condition", + "PinaDataModule", + 'Graph', + "SolverInterface", + "MultiSolverInterface" ] -from .meta import * from .label_tensor import LabelTensor -from .solvers.solver import SolverInterface +from .graph import Graph +from .solvers.solver import SolverInterface, MultiSolverInterface from .trainer import Trainer -from .plotter import Plotter from .condition.condition import Condition - -from .data import PinaDataModule - -from .optim import TorchOptimizer -from .optim import TorchScheduler -from .graph import Graph, RadiusGraph, KNNGraph diff --git a/pina/meta.py b/pina/meta.py deleted file mode 100644 index ac443b5..0000000 --- a/pina/meta.py +++ /dev/null @@ -1,22 +0,0 @@ -__all__ = [ - "__project__", - "__title__", - "__author__", - "__copyright__", - "__license__", - "__version__", - "__mail__", - "__maintainer__", - "__status__", -] - -__project__ = "PINA" -__title__ = "pina" -__author__ = "PINA Contributors" -__copyright__ = "2021-2025, PINA Contributors" -__license__ = "MIT" -__version__ = "0.2.0" -__mail__ = 'demo.nicola@gmail.com, dario.coscia@sissa.it' # TODO -__maintainer__ = __author__ -__status__ = "Alpha" -__packagename__ = "pina-mathlab" diff --git a/pina/plotter.py b/pina/plotter.py deleted file mode 100644 index eedec80..0000000 --- a/pina/plotter.py +++ /dev/null @@ -1,323 +0,0 @@ -""" Module for plotting. """ - -import matplotlib.pyplot as plt -import torch -from pina.callbacks import MetricTracker -from .label_tensor import LabelTensor - - -class Plotter: - """ - Implementation of a plotter class, for easy visualizations. - """ - - def plot_samples(self, problem, variables=None, filename=None, **kwargs): - """ - Plot the training grid samples. - - :param AbstractProblem problem: The PINA problem from where to plot - the domain. - :param list(str) variables: Variables to plot. If None, all variables - are plotted. If 'spatial', only spatial variables are plotted. If - 'temporal', only temporal variables are plotted. Defaults to None. - :param str filename: The file name to save the plot. If None, the plot - is shown using the setted matplotlib frontend. Default is None. - - .. todo:: - - Add support for 3D plots. - - Fix support for more complex problems. - - :Example: - >>> plotter = Plotter() - >>> plotter.plot_samples(problem=problem, variables='spatial') - """ - - if variables is None: - variables = problem.domain.variables - elif variables == "spatial": - variables = problem.spatial_domain.variables - elif variables == "temporal": - variables = problem.temporal_domain.variables - - if len(variables) not in [1, 2, 3]: - raise ValueError( - "Samples can be plotted only in " "dimensions 1, 2 and 3." - ) - - fig = plt.figure() - proj = "3d" if len(variables) == 3 else None - ax = fig.add_subplot(projection=proj) - for location in problem.input_pts: - coords = problem.input_pts[location].extract(variables).T.detach() - if len(variables) == 1: # 1D samples - ax.plot( - coords.flatten(), - torch.zeros(coords.flatten().shape), - ".", - label=location, - **kwargs, - ) - elif len(variables) == 2: - ax.plot(*coords, ".", label=location, **kwargs) - elif len(variables) == 3: - ax.scatter(*coords, ".", label=location, **kwargs) - - ax.set_xlabel(variables[0]) - try: - ax.set_ylabel(variables[1]) - except (IndexError, AttributeError): - pass - - try: - ax.set_zlabel(variables[2]) - except (IndexError, AttributeError): - pass - - plt.legend() - if filename: - plt.savefig(filename) - plt.close() - else: - plt.show() - - def _1d_plot(self, pts, pred, v, method, truth_solution=None, **kwargs): - """Plot solution for one dimensional function - - :param pts: Points to plot the solution. - :type pts: torch.Tensor - :param pred: SolverInterface solution evaluated at 'pts'. - :type pred: torch.Tensor - :param v: Fixed variables when plotting the solution. - :type v: torch.Tensor - :param method: Not used, kept for code compatibility - :type method: None - :param truth_solution: Real solution evaluated at 'pts', - defaults to None. - :type truth_solution: torch.Tensor, optional - """ - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8)) - - ax.plot(pts.extract(v), pred, label="Neural Network solution", **kwargs) - - if truth_solution: - truth_output = truth_solution(pts).detach() - ax.plot( - pts.extract(v), truth_output, label="True solution", **kwargs - ) - - # TODO: pred is a torch.Tensor, so no labels is available - # extra variable for labels should be - # passed in the function arguments. - # plt.ylabel(pred.labels[0]) - plt.legend() - - def _2d_plot( - self, pts, pred, v, res, method, truth_solution=None, **kwargs - ): - """Plot solution for two dimensional function - - :param pts: Points to plot the solution. - :type pts: torch.Tensor - :param pred: ``SolverInterface`` solution evaluated at 'pts'. - :type pred: torch.Tensor - :param v: Fixed variables when plotting the solution. - :type v: torch.Tensor - :param method: Matplotlib method to plot 2-dimensional data, - see https://matplotlib.org/stable/api/axes_api.html for - reference. - :type method: str - :param truth_solution: Real solution evaluated at 'pts', - defaults to None. - :type truth_solution: torch.Tensor, optional - """ - - grids = [p_.reshape(res, res) for p_ in pts.extract(v).T] - - pred_output = pred.reshape(res, res) - if truth_solution: - truth_output = ( - truth_solution(pts) - .float() - .reshape(res, res) - .as_subclass(torch.Tensor) - ) - fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6)) - - cb = getattr(ax[0], method)(*grids, pred_output, **kwargs) - fig.colorbar(cb, ax=ax[0]) - ax[0].title.set_text("Neural Network prediction") - cb = getattr(ax[1], method)(*grids, truth_output, **kwargs) - fig.colorbar(cb, ax=ax[1]) - ax[1].title.set_text("True solution") - cb = getattr(ax[2], method)( - *grids, (truth_output - pred_output), **kwargs - ) - fig.colorbar(cb, ax=ax[2]) - ax[2].title.set_text("Residual") - else: - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6)) - cb = getattr(ax, method)(*grids, pred_output, **kwargs) - fig.colorbar(cb, ax=ax) - ax.title.set_text("Neural Network prediction") - - def plot( - self, - solver, - components=None, - fixed_variables={}, - method="contourf", - res=256, - filename=None, - title=None, - **kwargs, - ): - """ - Plot sample of SolverInterface output. - - :param SolverInterface solver: The ``SolverInterface`` object instance. - :param str | list(str) components: The output variable(s) to plot. - If None, all the output variables of the problem are selected. - Default value is None. - :param dict fixed_variables: A dictionary with all the variables that - should be kept fixed during the plot. The keys of the dictionary - are the variables name whereas the values are the corresponding - values of the variables. Defaults is `dict()`. - :param str method: The matplotlib method to use for - plotting the solution. Available methods are {'contourf', 'pcolor'}. - Default is 'contourf'. - :param int res: The resolution, aka the number of points used for - plotting in each axis. Default is 256. - :param str title: The title for the plot. If None, the plot - is shown without a title. Default is None. - :param str filename: The file name to save the plot. If None, the plot - is shown using the setted matplotlib frontend. Default is None. - """ - - if components is None: - components = solver.problem.output_variables - - if isinstance(components, str): - components = [components] - - if not isinstance(components, list): - raise NotImplementedError( - "Output variables must be passed" - "as a string or a list of strings." - ) - - if len(components) > 1: - raise NotImplementedError( - "Multidimensional plots are not implemented, " - "set components to an available components of" - " the problem." - ) - v = [ - var - for var in solver.problem.input_variables - if var not in fixed_variables.keys() - ] - pts = solver.problem.domain.sample(res, "grid", variables=v) - - fixed_pts = torch.ones(pts.shape[0], len(fixed_variables)) - fixed_pts *= torch.tensor(list(fixed_variables.values())) - fixed_pts = fixed_pts.as_subclass(LabelTensor) - fixed_pts.labels = list(fixed_variables.keys()) - - pts = pts.append(fixed_pts) - pts = pts.to(device=solver.device) - - # computing soluting and sending to cpu - predicted_output = solver.forward(pts).extract(components) - predicted_output = ( - predicted_output.as_subclass(torch.Tensor).cpu().detach() - ) - pts = pts.cpu() - truth_solution = getattr(solver.problem, "truth_solution", None) - - if len(v) == 1: - self._1d_plot( - pts, predicted_output, v, method, truth_solution, **kwargs - ) - elif len(v) == 2: - self._2d_plot( - pts, predicted_output, v, res, method, truth_solution, **kwargs - ) - - plt.tight_layout() - if title is not None: - plt.title(title) - - if filename: - plt.savefig(filename) - plt.close() - else: - plt.show() - - def plot_loss( - self, - trainer, - metrics=None, - logy=False, - logx=False, - filename=None, - **kwargs, - ): - """ - Plot the loss function values during traininig. - - :param trainer: the PINA Trainer object instance. - :type trainer: Trainer - :param str | list(str) metric: The metrics to use in the y axis. If None, the mean loss - is plotted. - :param bool logy: If True, the y axis is in log scale. Default is - True. - :param bool logx: If True, the x axis is in log scale. Default is - True. - :param str filename: The file name to save the plot. If None, the plot - is shown using the setted matplotlib frontend. Default is None. - """ - - # check that MetricTracker has been used - list_ = [ - idx - for idx, s in enumerate(trainer.callbacks) - if isinstance(s, MetricTracker) - ] - if not bool(list_): - raise FileNotFoundError( - "MetricTracker should be used as a callback during training to" - " use this method." - ) - - # extract trainer metrics - trainer_metrics = trainer.callbacks[list_[0]].metrics - if metrics is None: - metrics = ["mean_loss"] - elif not isinstance(metrics, list): - raise ValueError("metrics must be class list.") - - # loop over metrics to plot - for metric in metrics: - if metric not in trainer_metrics: - raise ValueError( - f"{metric} not a valid metric. Available metrics are {list(trainer_metrics.keys())}." - ) - loss = trainer_metrics[metric] - epochs = range(len(loss)) - plt.plot(epochs, loss.cpu(), **kwargs) - - # plotting - plt.xlabel("epoch") - plt.ylabel("loss") - plt.legend() - - # log axis - if logy: - plt.yscale("log") - if logx: - plt.xscale("log") - - # saving in file - if filename: - plt.savefig(filename) - plt.close() diff --git a/pina/writer.py b/pina/writer.py deleted file mode 100644 index 831c1cc..0000000 --- a/pina/writer.py +++ /dev/null @@ -1,50 +0,0 @@ -""" Module for plotting. """ - -import matplotlib.pyplot as plt -import numpy as np -import torch - -from pina import LabelTensor - - -class Writer: - """ - Implementation of a writer class, for textual output. - """ - - def __init__(self, frequency_print=10, header="any") -> None: - """ - The constructor of the class. - - :param int frequency_print: the frequency in epochs of printing. - :param ['any', 'begin', 'none'] header: the header of the output. - """ - - self._frequency_print = frequency_print - self._header = header - - def header(self, trainer): - """ - The method for printing the header. - """ - header = [] - for condition_name in trainer.problem.conditions: - header.append(f"{condition_name}") - - return header - - def write_loss(self, trainer): - """ - The method for writing the output. - """ - pass - - def write_loss_in_loop(self, trainer, loss): - """ - The method for writing the output within the training loop. - - :param pina.trainer.Trainer trainer: the trainer object. - """ - - if trainer.trained_epoch % self._frequency_print == 0: - print(f"Epoch {trainer.trained_epoch:05d}: {loss.item():.5e}") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index bdc87ca..89b1293 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1,6 +1,6 @@ import torch import pytest -from pina import TorchOptimizer +from pina.optim import TorchOptimizer opt_list = [ torch.optim.Adam, torch.optim.AdamW, torch.optim.SGD, torch.optim.RMSprop diff --git a/tests/test_plotter.py b/tests/test_plotter.py deleted file mode 100644 index 838963c..0000000 --- a/tests/test_plotter.py +++ /dev/null @@ -1,75 +0,0 @@ -from pina.domain import CartesianDomain -from pina import Condition, Plotter -from matplotlib.testing.decorators import image_comparison -import matplotlib.pyplot as plt -from pina.problem import SpatialProblem -from pina.equation import FixedValue - -""" - -!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -TODO : Fix the tests once the Plotter class is updated -!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -class FooProblem1D(SpatialProblem): - - # assign output/ spatial and temporal variables - output_variables = ['u'] - spatial_domain = CartesianDomain({'x' : [-1, 1]}) - - # problem condition statement - conditions = { - 'D': Condition(location=CartesianDomain({'x': [-1, 1]}), equation=FixedValue(0.)), - } - -class FooProblem2D(SpatialProblem): - - # assign output/ spatial and temporal variables - output_variables = ['u'] - spatial_domain = CartesianDomain({'x' : [-1, 1], 'y': [-1, 1]}) - - # problem condition statement - conditions = { - 'D': Condition(location=CartesianDomain({'x' : [-1, 1], 'y': [-1, 1]}), equation=FixedValue(0.)), - } - -class FooProblem3D(SpatialProblem): - - # assign output/ spatial and temporal variables - output_variables = ['u'] - spatial_domain = CartesianDomain({'x' : [-1, 1], 'y': [-1, 1], 'z':[-1,1]}) - - # problem condition statement - conditions = { - 'D': Condition(location=CartesianDomain({'x' : [-1, 1], 'y': [-1, 1], 'z':[-1,1]}), equation=FixedValue(0.)), - } - - - -def test_constructor(): - Plotter() - -def test_plot_samples_1d(): - problem = FooProblem1D() - problem.discretise_domain(n=10, mode='grid', variables = 'x', locations=['D']) - pl = Plotter() - pl.plot_samples(problem=problem, filename='fig.png') - import os - os.remove('fig.png') - -def test_plot_samples_2d(): - problem = FooProblem2D() - problem.discretise_domain(n=10, mode='grid', variables = ['x', 'y'], locations=['D']) - pl = Plotter() - pl.plot_samples(problem=problem, filename='fig.png') - import os - os.remove('fig.png') - -def test_plot_samples_3d(): - problem = FooProblem3D() - problem.discretise_domain(n=10, mode='grid', variables = ['x', 'y', 'z'], locations=['D']) - pl = Plotter() - pl.plot_samples(problem=problem, filename='fig.png') - import os - os.remove('fig.png') -""" \ No newline at end of file diff --git a/tests/test_problem_zoo/test_supervised_problem.py b/tests/test_problem_zoo/test_supervised_problem.py index b9c7950..f3ac567 100644 --- a/tests/test_problem_zoo/test_supervised_problem.py +++ b/tests/test_problem_zoo/test_supervised_problem.py @@ -2,7 +2,7 @@ import torch from pina.problem import AbstractProblem from pina.condition import InputOutputPointsCondition from pina.problem.zoo.supervised_problem import SupervisedProblem -from pina import RadiusGraph +from pina.graph import RadiusGraph def test_constructor(): input_ = torch.rand((100,10)) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 4cde13e..5f3e3e7 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,7 +1,7 @@ import torch import pytest -from pina import TorchOptimizer, TorchScheduler +from pina.optim import TorchOptimizer, TorchScheduler opt_list = [ torch.optim.Adam,