Modify pina.__init__.py and rm useless .py files

* rm meta.py, plotter.py, writer.py
* modify __init__ file
* modify tests due to __init__ import

Committed by: Nicola Demo
Parent: 9c9d4fe7e4
Commit: 810d215ca0
pina/__init__.py
@@ -1,18 +1,15 @@
 __all__ = [
-    "Trainer", "LabelTensor", "Plotter", "Condition",
-    "PinaDataModule", 'TorchOptimizer', 'Graph',
-    "RadiusGraph", "KNNGraph"
+    "Trainer",
+    "LabelTensor",
+    "Condition",
+    "PinaDataModule",
+    'Graph',
+    "SolverInterface",
+    "MultiSolverInterface"
 ]

-from .meta import *
 from .label_tensor import LabelTensor
-from .solvers.solver import SolverInterface
+from .graph import Graph
+from .solvers.solver import SolverInterface, MultiSolverInterface
 from .trainer import Trainer
-from .plotter import Plotter
 from .condition.condition import Condition
-from .data import PinaDataModule
-from .optim import TorchOptimizer
-from .optim import TorchScheduler
-from .graph import Graph, RadiusGraph, KNNGraph
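Note on the hunk above: several names are no longer re-exported from the top-level pina namespace, so downstream imports have to move to the submodules. The sketch below is illustrative, not part of the commit, and uses only module paths that appear elsewhere in this diff.

# Old top-level re-exports dropped by this commit:
#   from pina import TorchOptimizer, TorchScheduler, RadiusGraph, KNNGraph, Plotter
# After the commit, import from the submodules instead (Plotter is removed
# outright, see the deleted pina/plotter.py below):
from pina.optim import TorchOptimizer, TorchScheduler
from pina.graph import Graph, RadiusGraph, KNNGraph
from pina import Trainer, LabelTensor, Condition, SolverInterface, MultiSolverInterface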
pina/meta.py (entire file removed, 22 lines)
@@ -1,22 +0,0 @@
__all__ = [
    "__project__",
    "__title__",
    "__author__",
    "__copyright__",
    "__license__",
    "__version__",
    "__mail__",
    "__maintainer__",
    "__status__",
]

__project__ = "PINA"
__title__ = "pina"
__author__ = "PINA Contributors"
__copyright__ = "2021-2025, PINA Contributors"
__license__ = "MIT"
__version__ = "0.2.0"
__mail__ = 'demo.nicola@gmail.com, dario.coscia@sissa.it'  # TODO
__maintainer__ = __author__
__status__ = "Alpha"
__packagename__ = "pina-mathlab"
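With pina/meta.py removed, the dunder metadata above (including __version__) is no longer defined inside the package. A minimal sketch of one way a user could still query the installed version is shown below; it assumes the distribution keeps the pina-mathlab name taken from the deleted __packagename__ and is illustrative, not part of the commit.

# Illustrative only: read the version from the installed distribution's
# metadata instead of the removed pina.meta module.
from importlib.metadata import PackageNotFoundError, version

try:
    pina_version = version("pina-mathlab")  # name from the deleted meta.py
except PackageNotFoundError:
    pina_version = None  # distribution not installed in this environment

print(pina_version)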
pina/plotter.py (entire file removed, 323 lines)
@@ -1,323 +0,0 @@
""" Module for plotting. """
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import torch
|
|
||||||
from pina.callbacks import MetricTracker
|
|
||||||
from .label_tensor import LabelTensor
|
|
||||||
|
|
||||||
|
|
||||||
class Plotter:
|
|
||||||
"""
|
|
||||||
Implementation of a plotter class, for easy visualizations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def plot_samples(self, problem, variables=None, filename=None, **kwargs):
|
|
||||||
"""
|
|
||||||
Plot the training grid samples.
|
|
||||||
|
|
||||||
:param AbstractProblem problem: The PINA problem from where to plot
|
|
||||||
the domain.
|
|
||||||
:param list(str) variables: Variables to plot. If None, all variables
|
|
||||||
are plotted. If 'spatial', only spatial variables are plotted. If
|
|
||||||
'temporal', only temporal variables are plotted. Defaults to None.
|
|
||||||
:param str filename: The file name to save the plot. If None, the plot
|
|
||||||
is shown using the setted matplotlib frontend. Default is None.
|
|
||||||
|
|
||||||
.. todo::
|
|
||||||
- Add support for 3D plots.
|
|
||||||
- Fix support for more complex problems.
|
|
||||||
|
|
||||||
:Example:
|
|
||||||
>>> plotter = Plotter()
|
|
||||||
>>> plotter.plot_samples(problem=problem, variables='spatial')
|
|
||||||
"""
|
|
||||||
|
|
||||||
if variables is None:
|
|
||||||
variables = problem.domain.variables
|
|
||||||
elif variables == "spatial":
|
|
||||||
variables = problem.spatial_domain.variables
|
|
||||||
elif variables == "temporal":
|
|
||||||
variables = problem.temporal_domain.variables
|
|
||||||
|
|
||||||
if len(variables) not in [1, 2, 3]:
|
|
||||||
raise ValueError(
|
|
||||||
"Samples can be plotted only in " "dimensions 1, 2 and 3."
|
|
||||||
)
|
|
||||||
|
|
||||||
fig = plt.figure()
|
|
||||||
proj = "3d" if len(variables) == 3 else None
|
|
||||||
ax = fig.add_subplot(projection=proj)
|
|
||||||
for location in problem.input_pts:
|
|
||||||
coords = problem.input_pts[location].extract(variables).T.detach()
|
|
||||||
if len(variables) == 1: # 1D samples
|
|
||||||
ax.plot(
|
|
||||||
coords.flatten(),
|
|
||||||
torch.zeros(coords.flatten().shape),
|
|
||||||
".",
|
|
||||||
label=location,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
elif len(variables) == 2:
|
|
||||||
ax.plot(*coords, ".", label=location, **kwargs)
|
|
||||||
elif len(variables) == 3:
|
|
||||||
ax.scatter(*coords, ".", label=location, **kwargs)
|
|
||||||
|
|
||||||
ax.set_xlabel(variables[0])
|
|
||||||
try:
|
|
||||||
ax.set_ylabel(variables[1])
|
|
||||||
except (IndexError, AttributeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
ax.set_zlabel(variables[2])
|
|
||||||
except (IndexError, AttributeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
plt.legend()
|
|
||||||
if filename:
|
|
||||||
plt.savefig(filename)
|
|
||||||
plt.close()
|
|
||||||
else:
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def _1d_plot(self, pts, pred, v, method, truth_solution=None, **kwargs):
|
|
||||||
"""Plot solution for one dimensional function
|
|
||||||
|
|
||||||
:param pts: Points to plot the solution.
|
|
||||||
:type pts: torch.Tensor
|
|
||||||
:param pred: SolverInterface solution evaluated at 'pts'.
|
|
||||||
:type pred: torch.Tensor
|
|
||||||
:param v: Fixed variables when plotting the solution.
|
|
||||||
:type v: torch.Tensor
|
|
||||||
:param method: Not used, kept for code compatibility
|
|
||||||
:type method: None
|
|
||||||
:param truth_solution: Real solution evaluated at 'pts',
|
|
||||||
defaults to None.
|
|
||||||
:type truth_solution: torch.Tensor, optional
|
|
||||||
"""
|
|
||||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
|
|
||||||
|
|
||||||
ax.plot(pts.extract(v), pred, label="Neural Network solution", **kwargs)
|
|
||||||
|
|
||||||
if truth_solution:
|
|
||||||
truth_output = truth_solution(pts).detach()
|
|
||||||
ax.plot(
|
|
||||||
pts.extract(v), truth_output, label="True solution", **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: pred is a torch.Tensor, so no labels is available
|
|
||||||
# extra variable for labels should be
|
|
||||||
# passed in the function arguments.
|
|
||||||
# plt.ylabel(pred.labels[0])
|
|
||||||
plt.legend()
|
|
||||||
|
|
||||||
def _2d_plot(
|
|
||||||
self, pts, pred, v, res, method, truth_solution=None, **kwargs
|
|
||||||
):
|
|
||||||
"""Plot solution for two dimensional function
|
|
||||||
|
|
||||||
:param pts: Points to plot the solution.
|
|
||||||
:type pts: torch.Tensor
|
|
||||||
:param pred: ``SolverInterface`` solution evaluated at 'pts'.
|
|
||||||
:type pred: torch.Tensor
|
|
||||||
:param v: Fixed variables when plotting the solution.
|
|
||||||
:type v: torch.Tensor
|
|
||||||
:param method: Matplotlib method to plot 2-dimensional data,
|
|
||||||
see https://matplotlib.org/stable/api/axes_api.html for
|
|
||||||
reference.
|
|
||||||
:type method: str
|
|
||||||
:param truth_solution: Real solution evaluated at 'pts',
|
|
||||||
defaults to None.
|
|
||||||
:type truth_solution: torch.Tensor, optional
|
|
||||||
"""
|
|
||||||
|
|
||||||
grids = [p_.reshape(res, res) for p_ in pts.extract(v).T]
|
|
||||||
|
|
||||||
pred_output = pred.reshape(res, res)
|
|
||||||
if truth_solution:
|
|
||||||
truth_output = (
|
|
||||||
truth_solution(pts)
|
|
||||||
.float()
|
|
||||||
.reshape(res, res)
|
|
||||||
.as_subclass(torch.Tensor)
|
|
||||||
)
|
|
||||||
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))
|
|
||||||
|
|
||||||
cb = getattr(ax[0], method)(*grids, pred_output, **kwargs)
|
|
||||||
fig.colorbar(cb, ax=ax[0])
|
|
||||||
ax[0].title.set_text("Neural Network prediction")
|
|
||||||
cb = getattr(ax[1], method)(*grids, truth_output, **kwargs)
|
|
||||||
fig.colorbar(cb, ax=ax[1])
|
|
||||||
ax[1].title.set_text("True solution")
|
|
||||||
cb = getattr(ax[2], method)(
|
|
||||||
*grids, (truth_output - pred_output), **kwargs
|
|
||||||
)
|
|
||||||
fig.colorbar(cb, ax=ax[2])
|
|
||||||
ax[2].title.set_text("Residual")
|
|
||||||
else:
|
|
||||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
|
|
||||||
cb = getattr(ax, method)(*grids, pred_output, **kwargs)
|
|
||||||
fig.colorbar(cb, ax=ax)
|
|
||||||
ax.title.set_text("Neural Network prediction")
|
|
||||||
|
|
||||||
def plot(
|
|
||||||
self,
|
|
||||||
solver,
|
|
||||||
components=None,
|
|
||||||
fixed_variables={},
|
|
||||||
method="contourf",
|
|
||||||
res=256,
|
|
||||||
filename=None,
|
|
||||||
title=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Plot sample of SolverInterface output.
|
|
||||||
|
|
||||||
:param SolverInterface solver: The ``SolverInterface`` object instance.
|
|
||||||
:param str | list(str) components: The output variable(s) to plot.
|
|
||||||
If None, all the output variables of the problem are selected.
|
|
||||||
Default value is None.
|
|
||||||
:param dict fixed_variables: A dictionary with all the variables that
|
|
||||||
should be kept fixed during the plot. The keys of the dictionary
|
|
||||||
are the variables name whereas the values are the corresponding
|
|
||||||
values of the variables. Defaults is `dict()`.
|
|
||||||
:param str method: The matplotlib method to use for
|
|
||||||
plotting the solution. Available methods are {'contourf', 'pcolor'}.
|
|
||||||
Default is 'contourf'.
|
|
||||||
:param int res: The resolution, aka the number of points used for
|
|
||||||
plotting in each axis. Default is 256.
|
|
||||||
:param str title: The title for the plot. If None, the plot
|
|
||||||
is shown without a title. Default is None.
|
|
||||||
:param str filename: The file name to save the plot. If None, the plot
|
|
||||||
is shown using the setted matplotlib frontend. Default is None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if components is None:
|
|
||||||
components = solver.problem.output_variables
|
|
||||||
|
|
||||||
if isinstance(components, str):
|
|
||||||
components = [components]
|
|
||||||
|
|
||||||
if not isinstance(components, list):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Output variables must be passed"
|
|
||||||
"as a string or a list of strings."
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(components) > 1:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Multidimensional plots are not implemented, "
|
|
||||||
"set components to an available components of"
|
|
||||||
" the problem."
|
|
||||||
)
|
|
||||||
v = [
|
|
||||||
var
|
|
||||||
for var in solver.problem.input_variables
|
|
||||||
if var not in fixed_variables.keys()
|
|
||||||
]
|
|
||||||
pts = solver.problem.domain.sample(res, "grid", variables=v)
|
|
||||||
|
|
||||||
fixed_pts = torch.ones(pts.shape[0], len(fixed_variables))
|
|
||||||
fixed_pts *= torch.tensor(list(fixed_variables.values()))
|
|
||||||
fixed_pts = fixed_pts.as_subclass(LabelTensor)
|
|
||||||
fixed_pts.labels = list(fixed_variables.keys())
|
|
||||||
|
|
||||||
pts = pts.append(fixed_pts)
|
|
||||||
pts = pts.to(device=solver.device)
|
|
||||||
|
|
||||||
# computing soluting and sending to cpu
|
|
||||||
predicted_output = solver.forward(pts).extract(components)
|
|
||||||
predicted_output = (
|
|
||||||
predicted_output.as_subclass(torch.Tensor).cpu().detach()
|
|
||||||
)
|
|
||||||
pts = pts.cpu()
|
|
||||||
truth_solution = getattr(solver.problem, "truth_solution", None)
|
|
||||||
|
|
||||||
if len(v) == 1:
|
|
||||||
self._1d_plot(
|
|
||||||
pts, predicted_output, v, method, truth_solution, **kwargs
|
|
||||||
)
|
|
||||||
elif len(v) == 2:
|
|
||||||
self._2d_plot(
|
|
||||||
pts, predicted_output, v, res, method, truth_solution, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
if title is not None:
|
|
||||||
plt.title(title)
|
|
||||||
|
|
||||||
if filename:
|
|
||||||
plt.savefig(filename)
|
|
||||||
plt.close()
|
|
||||||
else:
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def plot_loss(
|
|
||||||
self,
|
|
||||||
trainer,
|
|
||||||
metrics=None,
|
|
||||||
logy=False,
|
|
||||||
logx=False,
|
|
||||||
filename=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Plot the loss function values during traininig.
|
|
||||||
|
|
||||||
:param trainer: the PINA Trainer object instance.
|
|
||||||
:type trainer: Trainer
|
|
||||||
:param str | list(str) metric: The metrics to use in the y axis. If None, the mean loss
|
|
||||||
is plotted.
|
|
||||||
:param bool logy: If True, the y axis is in log scale. Default is
|
|
||||||
True.
|
|
||||||
:param bool logx: If True, the x axis is in log scale. Default is
|
|
||||||
True.
|
|
||||||
:param str filename: The file name to save the plot. If None, the plot
|
|
||||||
is shown using the setted matplotlib frontend. Default is None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# check that MetricTracker has been used
|
|
||||||
list_ = [
|
|
||||||
idx
|
|
||||||
for idx, s in enumerate(trainer.callbacks)
|
|
||||||
if isinstance(s, MetricTracker)
|
|
||||||
]
|
|
||||||
if not bool(list_):
|
|
||||||
raise FileNotFoundError(
|
|
||||||
"MetricTracker should be used as a callback during training to"
|
|
||||||
" use this method."
|
|
||||||
)
|
|
||||||
|
|
||||||
# extract trainer metrics
|
|
||||||
trainer_metrics = trainer.callbacks[list_[0]].metrics
|
|
||||||
if metrics is None:
|
|
||||||
metrics = ["mean_loss"]
|
|
||||||
elif not isinstance(metrics, list):
|
|
||||||
raise ValueError("metrics must be class list.")
|
|
||||||
|
|
||||||
# loop over metrics to plot
|
|
||||||
for metric in metrics:
|
|
||||||
if metric not in trainer_metrics:
|
|
||||||
raise ValueError(
|
|
||||||
f"{metric} not a valid metric. Available metrics are {list(trainer_metrics.keys())}."
|
|
||||||
)
|
|
||||||
loss = trainer_metrics[metric]
|
|
||||||
epochs = range(len(loss))
|
|
||||||
plt.plot(epochs, loss.cpu(), **kwargs)
|
|
||||||
|
|
||||||
# plotting
|
|
||||||
plt.xlabel("epoch")
|
|
||||||
plt.ylabel("loss")
|
|
||||||
plt.legend()
|
|
||||||
|
|
||||||
# log axis
|
|
||||||
if logy:
|
|
||||||
plt.yscale("log")
|
|
||||||
if logx:
|
|
||||||
plt.xscale("log")
|
|
||||||
|
|
||||||
# saving in file
|
|
||||||
if filename:
|
|
||||||
plt.savefig(filename)
|
|
||||||
plt.close()
|
|
||||||
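Since Plotter is deleted without a replacement in this commit, users who relied on Plotter.plot_loss can produce an equivalent figure with plain matplotlib. The sketch below mirrors the deleted code and assumes a MetricTracker callback exposing a metrics dict of per-epoch tensors, as the removed implementation did; it is illustrative, not part of the commit.

# Illustrative sketch: a stand-in for the removed Plotter.plot_loss.
import matplotlib.pyplot as plt

def plot_mean_loss(trainer_metrics, metric="mean_loss", filename=None):
    # trainer_metrics is assumed to be the MetricTracker.metrics dict of
    # per-epoch tensors used by the deleted code above.
    loss = trainer_metrics[metric].cpu()
    plt.plot(range(len(loss)), loss, label=metric)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()
    if filename:
        plt.savefig(filename)
        plt.close()
    else:
        plt.show()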
pina/writer.py (entire file removed, 50 lines)
@@ -1,50 +0,0 @@
""" Module for plotting. """
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from pina import LabelTensor
|
|
||||||
|
|
||||||
|
|
||||||
class Writer:
|
|
||||||
"""
|
|
||||||
Implementation of a writer class, for textual output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, frequency_print=10, header="any") -> None:
|
|
||||||
"""
|
|
||||||
The constructor of the class.
|
|
||||||
|
|
||||||
:param int frequency_print: the frequency in epochs of printing.
|
|
||||||
:param ['any', 'begin', 'none'] header: the header of the output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self._frequency_print = frequency_print
|
|
||||||
self._header = header
|
|
||||||
|
|
||||||
def header(self, trainer):
|
|
||||||
"""
|
|
||||||
The method for printing the header.
|
|
||||||
"""
|
|
||||||
header = []
|
|
||||||
for condition_name in trainer.problem.conditions:
|
|
||||||
header.append(f"{condition_name}")
|
|
||||||
|
|
||||||
return header
|
|
||||||
|
|
||||||
def write_loss(self, trainer):
|
|
||||||
"""
|
|
||||||
The method for writing the output.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def write_loss_in_loop(self, trainer, loss):
|
|
||||||
"""
|
|
||||||
The method for writing the output within the training loop.
|
|
||||||
|
|
||||||
:param pina.trainer.Trainer trainer: the trainer object.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if trainer.trained_epoch % self._frequency_print == 0:
|
|
||||||
print(f"Epoch {trainer.trained_epoch:05d}: {loss.item():.5e}")
|
|
||||||
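The Writer class removed above only produced a periodic console printout of the loss. A one-function stand-in is sketched below for users who want the same output from their own training loop or callback; it is illustrative and not part of the commit.

# Illustrative stand-in for the removed Writer.write_loss_in_loop.
def print_loss(epoch: int, loss: float, frequency_print: int = 10) -> None:
    # Same format and cadence as the deleted method above.
    if epoch % frequency_print == 0:
        print(f"Epoch {epoch:05d}: {loss:.5e}")

print_loss(epoch=20, loss=1.234e-3)  # prints: Epoch 00020: 1.23400e-03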
Test for TorchOptimizer (file path not shown in this view):
@@ -1,6 +1,6 @@
 import torch
 import pytest
-from pina import TorchOptimizer
+from pina.optim import TorchOptimizer

 opt_list = [
     torch.optim.Adam, torch.optim.AdamW, torch.optim.SGD, torch.optim.RMSprop
Deleted test file for the Plotter class (file path not shown in this view, 75 lines):
@@ -1,75 +0,0 @@
from pina.domain import CartesianDomain
from pina import Condition, Plotter
from matplotlib.testing.decorators import image_comparison
import matplotlib.pyplot as plt
from pina.problem import SpatialProblem
from pina.equation import FixedValue

"""

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
TODO : Fix the tests once the Plotter class is updated
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

class FooProblem1D(SpatialProblem):

    # assign output/ spatial and temporal variables
    output_variables = ['u']
    spatial_domain = CartesianDomain({'x' : [-1, 1]})

    # problem condition statement
    conditions = {
        'D': Condition(location=CartesianDomain({'x': [-1, 1]}), equation=FixedValue(0.)),
    }

class FooProblem2D(SpatialProblem):

    # assign output/ spatial and temporal variables
    output_variables = ['u']
    spatial_domain = CartesianDomain({'x' : [-1, 1], 'y': [-1, 1]})

    # problem condition statement
    conditions = {
        'D': Condition(location=CartesianDomain({'x' : [-1, 1], 'y': [-1, 1]}), equation=FixedValue(0.)),
    }

class FooProblem3D(SpatialProblem):

    # assign output/ spatial and temporal variables
    output_variables = ['u']
    spatial_domain = CartesianDomain({'x' : [-1, 1], 'y': [-1, 1], 'z':[-1,1]})

    # problem condition statement
    conditions = {
        'D': Condition(location=CartesianDomain({'x' : [-1, 1], 'y': [-1, 1], 'z':[-1,1]}), equation=FixedValue(0.)),
    }


def test_constructor():
    Plotter()

def test_plot_samples_1d():
    problem = FooProblem1D()
    problem.discretise_domain(n=10, mode='grid', variables = 'x', locations=['D'])
    pl = Plotter()
    pl.plot_samples(problem=problem, filename='fig.png')
    import os
    os.remove('fig.png')

def test_plot_samples_2d():
    problem = FooProblem2D()
    problem.discretise_domain(n=10, mode='grid', variables = ['x', 'y'], locations=['D'])
    pl = Plotter()
    pl.plot_samples(problem=problem, filename='fig.png')
    import os
    os.remove('fig.png')

def test_plot_samples_3d():
    problem = FooProblem3D()
    problem.discretise_domain(n=10, mode='grid', variables = ['x', 'y', 'z'], locations=['D'])
    pl = Plotter()
    pl.plot_samples(problem=problem, filename='fig.png')
    import os
    os.remove('fig.png')
"""
Test for graph construction (file path not shown in this view):
@@ -2,7 +2,7 @@ import torch
 from pina.problem import AbstractProblem
 from pina.condition import InputOutputPointsCondition
 from pina.problem.zoo.supervised_problem import SupervisedProblem
-from pina import RadiusGraph
+from pina.graph import RadiusGraph

 def test_constructor():
     input_ = torch.rand((100,10))
Test for TorchScheduler (file path not shown in this view):
@@ -1,7 +1,7 @@

 import torch
 import pytest
-from pina import TorchOptimizer, TorchScheduler
+from pina.optim import TorchOptimizer, TorchScheduler

 opt_list = [
     torch.optim.Adam,