Fix rendering and codacy
This commit is contained in:
committed by
Nicola Demo
parent
05105dd517
commit
001d1fc9cf
@@ -118,7 +118,7 @@ class Collector:
|
|||||||
"""
|
"""
|
||||||
Store inside data collections the sampled data of the problem. These
|
Store inside data collections the sampled data of the problem. These
|
||||||
comes from the conditions that require sampling (e.g.
|
comes from the conditions that require sampling (e.g.
|
||||||
:class:`~pina.condition.domain_equation_condition.
|
:class:`~pina.condition.domain_equation_condition.\
|
||||||
DomainEquationCondition`).
|
DomainEquationCondition`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -244,7 +244,7 @@ class PinaSampler:
|
|||||||
|
|
||||||
class PinaDataModule(LightningDataModule):
|
class PinaDataModule(LightningDataModule):
|
||||||
"""
|
"""
|
||||||
This class extends :class:`lightning.pytorch.LightningDataModule`,
|
This class extends :class:`~lightning.pytorch.core.LightningDataModule`,
|
||||||
allowing proper creation and management of different types of datasets
|
allowing proper creation and management of different types of datasets
|
||||||
defined in PINA.
|
defined in PINA.
|
||||||
"""
|
"""
|
||||||
@@ -274,18 +274,18 @@ class PinaDataModule(LightningDataModule):
|
|||||||
:param float val_size: Fraction of elements in the validation split. It
|
:param float val_size: Fraction of elements in the validation split. It
|
||||||
must be in the range [0, 1].
|
must be in the range [0, 1].
|
||||||
:param batch_size: The batch size used for training. If ``None``, the
|
:param batch_size: The batch size used for training. If ``None``, the
|
||||||
entire dataset is returned in a single batch.
|
entire dataset is returned in a single batch. Default is ``None``.
|
||||||
:type batch_size: int | None
|
:type batch_size: int
|
||||||
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
:param bool shuffle: Whether to shuffle the dataset before splitting.
|
||||||
Default True.
|
Default ``Tru``e.
|
||||||
:param bool repeat: Whether to repeat the dataset indefinitely.
|
:param bool repeat: Whether to repeat the dataset indefinitely.
|
||||||
Default False.
|
Default ``False``.
|
||||||
:param automatic_batching: Whether to enable automatic batching.
|
:param automatic_batching: Whether to enable automatic batching.
|
||||||
Default False.
|
Default ``False``.
|
||||||
:param int num_workers: Number of worker threads for data loading.
|
:param int num_workers: Number of worker threads for data loading.
|
||||||
Default 0 (serial loading).
|
Default ``0`` (serial loading).
|
||||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||||
transfer to GPU. Default False.
|
transfer to GPU. Default ``False``.
|
||||||
|
|
||||||
:raises ValueError: If at least one of the splits is negative.
|
:raises ValueError: If at least one of the splits is negative.
|
||||||
:raises ValueError: If the sum of the splits is different from 1.
|
:raises ValueError: If the sum of the splits is different from 1.
|
||||||
@@ -643,7 +643,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
Return all the input points coming from all the datasets.
|
Return all the input points coming from all the datasets.
|
||||||
|
|
||||||
:return: The input points for training.
|
:return: The input points for training.
|
||||||
:rtype dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
to_return = {}
|
to_return = {}
|
||||||
|
|||||||
@@ -12,12 +12,13 @@ class PinaDatasetFactory:
|
|||||||
"""
|
"""
|
||||||
Factory class for the PINA dataset.
|
Factory class for the PINA dataset.
|
||||||
|
|
||||||
Depending on the type inside the conditions, it creates a different dataset
|
Depending on the data type inside the conditions, it instanciate an object
|
||||||
object:
|
belonging to the appropriate subclass of
|
||||||
|
:class:`~pina.data.dataset.PinaDataset`. The possible subclasses are:
|
||||||
|
|
||||||
- :class:`~pina.data.dataset.PinaTensorDataset` for handling
|
- :class:`~pina.data.dataset.PinaTensorDataset`, for handling \
|
||||||
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
|
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
|
||||||
- :class:`~pina.data.dataset.PinaGraphDataset` for handling
|
- :class:`~pina.data.dataset.PinaGraphDataset`, for handling \
|
||||||
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -33,8 +34,7 @@ class PinaDatasetFactory:
|
|||||||
:param dict conditions_dict: Dictionary containing all the conditions
|
:param dict conditions_dict: Dictionary containing all the conditions
|
||||||
to be included in the dataset instance.
|
to be included in the dataset instance.
|
||||||
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
|
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
|
||||||
:rtype: pina.data.dataset.PinaTensorDataset |
|
:rtype: PinaTensorDataset | PinaGraphDataset
|
||||||
pina.data.dataset.PinaGraphDataset
|
|
||||||
|
|
||||||
:raises ValueError: If an empty dictionary is provided.
|
:raises ValueError: If an empty dictionary is provided.
|
||||||
"""
|
"""
|
||||||
@@ -74,8 +74,9 @@ class PinaDatasetFactory:
|
|||||||
|
|
||||||
class PinaDataset(Dataset, ABC):
|
class PinaDataset(Dataset, ABC):
|
||||||
"""
|
"""
|
||||||
Abstract class for the PINA dataset. It defines the common interface for
|
Abstract class for the PINA dataset which extends the PyTorch
|
||||||
:class:`~pina.data.dataset.PinaTensorDataset` and
|
:class:`~torch.utils.data.Dataset` class. It defines the common interface
|
||||||
|
for :class:`~pina.data.dataset.PinaTensorDataset` and
|
||||||
:class:`~pina.data.dataset.PinaGraphDataset` classes.
|
:class:`~pina.data.dataset.PinaGraphDataset` classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -83,13 +84,15 @@ class PinaDataset(Dataset, ABC):
|
|||||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize :class:`~pina.data.dataset.PinaDataset` instance by storing
|
Initialize the instance by storing the conditions dictionary, the
|
||||||
the provided conditions dictionary, and the automatic batching flag.
|
maximum number of items per conditions to consider, and the automatic
|
||||||
|
batching flag.
|
||||||
|
|
||||||
:param dict conditions_dict: Dictionary containing the conditions with
|
:param dict conditions_dict: A dictionary mapping condition names to
|
||||||
data.
|
their respective data. Each key represents a condition name, and the
|
||||||
:param dict max_conditions_lengths: Specifies the maximum number of data
|
corresponding value is a dictionary containing the associated data.
|
||||||
points to include in a single batch for each condition.
|
:param dict max_conditions_lengths: Maximum number of data points that
|
||||||
|
can be included in a single batch per condition.
|
||||||
:param bool automatic_batching: Indicates whether PyTorch automatic
|
:param bool automatic_batching: Indicates whether PyTorch automatic
|
||||||
batching is enabled in
|
batching is enabled in
|
||||||
:class:`~pina.data.data_module.PinaDataModule`.
|
:class:`~pina.data.data_module.PinaDataModule`.
|
||||||
@@ -258,8 +261,8 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
Reshape properly ``data`` tensor to be processed handle by the graph
|
Reshape properly ``data`` tensor to be processed handle by the graph
|
||||||
based models.
|
based models.
|
||||||
|
|
||||||
:param data: torch.Tensor object of shape (N, ...) where N is the
|
:param data: torch.Tensor object of shape ``(N, ...)`` where ``N`` is
|
||||||
number of data points.
|
the number of data objects.
|
||||||
:type data: torch.Tensor | LabelTensor
|
:type data: torch.Tensor | LabelTensor
|
||||||
:return: Reshaped tensor object.
|
:return: Reshaped tensor object.
|
||||||
:rtype: torch.Tensor | LabelTensor
|
:rtype: torch.Tensor | LabelTensor
|
||||||
@@ -275,7 +278,7 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
:param data: List of items to collate in a single batch.
|
:param data: List of items to collate in a single batch.
|
||||||
:type data: list[Data] | list[Graph]
|
:type data: list[Data] | list[Graph]
|
||||||
:return: Batch object.
|
:return: Batch object.
|
||||||
:rtype: Batch | PinaBatch
|
:rtype: Batch | LabelBatch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data[0], Data):
|
if isinstance(data[0], Data):
|
||||||
|
|||||||
@@ -23,9 +23,8 @@ class Graph(Data):
|
|||||||
Create a new instance of the :class:`~pina.graph.Graph` class by
|
Create a new instance of the :class:`~pina.graph.Graph` class by
|
||||||
checking the consistency of the input data and storing the attributes.
|
checking the consistency of the input data and storing the attributes.
|
||||||
|
|
||||||
:param kwargs: Parameters used to initialize the
|
:param dict kwargs: Parameters used to initialize the
|
||||||
:class:`~pina.graph.Graph` object.
|
:class:`~pina.graph.Graph` object.
|
||||||
:type kwargs: dict
|
|
||||||
:return: A new instance of the :class:`~pina.graph.Graph` class.
|
:return: A new instance of the :class:`~pina.graph.Graph` class.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
@@ -56,8 +55,8 @@ class Graph(Data):
|
|||||||
:param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
|
:param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
|
||||||
number of features per node.
|
number of features per node.
|
||||||
:type x: torch.Tensor, LabelTensor
|
:type x: torch.Tensor, LabelTensor
|
||||||
:param torch.Tensor edge_index: A tensor of shape ``(2, E)`` representing
|
:param torch.Tensor edge_index: A tensor of shape ``(2, E)``
|
||||||
the indices of the graph's edges.
|
representing the indices of the graph's edges.
|
||||||
:param pos: A tensor of shape ``(N, D)`` representing the positions of
|
:param pos: A tensor of shape ``(N, D)`` representing the positions of
|
||||||
``N`` points in ``D``-dimensional space.
|
``N`` points in ``D``-dimensional space.
|
||||||
:type pos: torch.Tensor | LabelTensor
|
:type pos: torch.Tensor | LabelTensor
|
||||||
@@ -80,8 +79,7 @@ class Graph(Data):
|
|||||||
"""
|
"""
|
||||||
Check the consistency of the types of the input data.
|
Check the consistency of the types of the input data.
|
||||||
|
|
||||||
:param kwargs: Attributes to be checked for consistency.
|
:param dict kwargs: Attributes to be checked for consistency.
|
||||||
:type kwargs: dict
|
|
||||||
"""
|
"""
|
||||||
# default types, specified in cls.__new__, by default they are Nont
|
# default types, specified in cls.__new__, by default they are Nont
|
||||||
# if specified in **kwargs they get override
|
# if specified in **kwargs they get override
|
||||||
@@ -134,7 +132,8 @@ class Graph(Data):
|
|||||||
Check if the edge attribute tensor is consistent in type and shape
|
Check if the edge attribute tensor is consistent in type and shape
|
||||||
with the edge index.
|
with the edge index.
|
||||||
|
|
||||||
:param torch.Tensor edge_attr: The edge attribute tensor.
|
:param edge_attr: The edge attribute tensor.
|
||||||
|
:type edge_attr: torch.Tensor | LabelTensor
|
||||||
:param torch.Tensor edge_index: The edge index tensor.
|
:param torch.Tensor edge_index: The edge index tensor.
|
||||||
:raises ValueError: If the edge attribute tensor is not consistent.
|
:raises ValueError: If the edge attribute tensor is not consistent.
|
||||||
"""
|
"""
|
||||||
@@ -156,8 +155,11 @@ class Graph(Data):
|
|||||||
Check if the input tensor x is consistent with the position tensor
|
Check if the input tensor x is consistent with the position tensor
|
||||||
`pos`.
|
`pos`.
|
||||||
|
|
||||||
:param torch.Tensor x: The input tensor.
|
:param x: The input tensor.
|
||||||
:param torch.Tensor pos: The position tensor.
|
:type x: torch.Tensor | LabelTensor
|
||||||
|
:param pos: The position tensor.
|
||||||
|
:type pos: torch.Tensor | LabelTensor
|
||||||
|
:raises ValueError: If the input tensor is not consistent.
|
||||||
"""
|
"""
|
||||||
if x is not None:
|
if x is not None:
|
||||||
check_consistency(x, (torch.Tensor, LabelTensor))
|
check_consistency(x, (torch.Tensor, LabelTensor))
|
||||||
@@ -166,9 +168,6 @@ class Graph(Data):
|
|||||||
if pos is not None:
|
if pos is not None:
|
||||||
if x.size(0) != pos.size(0):
|
if x.size(0) != pos.size(0):
|
||||||
raise ValueError("Inconsistent number of nodes.")
|
raise ValueError("Inconsistent number of nodes.")
|
||||||
if pos is not None:
|
|
||||||
if x.size(0) != pos.size(0):
|
|
||||||
raise ValueError("Inconsistent number of nodes.")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _preprocess_edge_index(edge_index, undirected):
|
def _preprocess_edge_index(edge_index, undirected):
|
||||||
@@ -292,7 +291,7 @@ class GraphBuilder:
|
|||||||
class RadiusGraph(GraphBuilder):
|
class RadiusGraph(GraphBuilder):
|
||||||
"""
|
"""
|
||||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||||
edge_index based on a radius. Each point is connected to all the points
|
``edge_index`` based on a radius. Each point is connected to all the points
|
||||||
within the radius.
|
within the radius.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -305,11 +304,10 @@ class RadiusGraph(GraphBuilder):
|
|||||||
``N`` points in ``D``-dimensional space.
|
``N`` points in ``D``-dimensional space.
|
||||||
:type pos: torch.Tensor | LabelTensor
|
:type pos: torch.Tensor | LabelTensor
|
||||||
:param float radius: The radius within which points are connected.
|
:param float radius: The radius within which points are connected.
|
||||||
:param kwargs: Additional keyword arguments to be passed to the
|
:param dict kwargs: The additional keyword arguments to be passed to
|
||||||
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
|
:class:`GraphBuilder` and :class:`Graph` classes.
|
||||||
constructors.
|
:return: A :class:`~pina.graph.Graph` instance with the computed
|
||||||
:return: A :class:`~pina.graph.Graph` instance containing the input
|
``edge_index``.
|
||||||
information and the computed ``edge_index``.
|
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
edge_index = cls.compute_radius_graph(pos, radius)
|
edge_index = cls.compute_radius_graph(pos, radius)
|
||||||
@@ -318,16 +316,16 @@ class RadiusGraph(GraphBuilder):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_radius_graph(points, radius):
|
def compute_radius_graph(points, radius):
|
||||||
"""
|
"""
|
||||||
Computes ``edge_index`` for a given set of points base on the radius.
|
Computes the ``edge_index`` based on the radius. Each point is connected
|
||||||
Each point is connected to all the points within the radius.
|
to all the points within the radius.
|
||||||
|
|
||||||
:param points: A tensor of shape ``(N, D)`` representing the positions
|
:param points: A tensor of shape ``(N, D)`` representing the positions
|
||||||
of ``N`` points in ``D``-dimensional space.
|
of ``N`` points in ``D``-dimensional space.
|
||||||
:type points: torch.Tensor | LabelTensor
|
:type points: torch.Tensor | LabelTensor
|
||||||
:param float radius: The number of nearest neighbors to find for each
|
:param float radius: The radius within which points are connected.
|
||||||
point.
|
:return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
|
||||||
:rtype torch.Tensor: A tensor of shape ``(2, E)``, where ``E`` is the
|
representing the edge indices of the graph.
|
||||||
number of edges, representing the edge indices of the KNN graph.
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
dist = torch.cdist(points, points, p=2)
|
dist = torch.cdist(points, points, p=2)
|
||||||
return (
|
return (
|
||||||
@@ -340,7 +338,7 @@ class RadiusGraph(GraphBuilder):
|
|||||||
class KNNGraph(GraphBuilder):
|
class KNNGraph(GraphBuilder):
|
||||||
"""
|
"""
|
||||||
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||||
edge_index based on a K-nearest neighbors algorithm.
|
``edge_index`` based on a K-nearest neighbors algorithm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, pos, neighbours, **kwargs):
|
def __new__(cls, pos, neighbours, **kwargs):
|
||||||
@@ -353,12 +351,11 @@ class KNNGraph(GraphBuilder):
|
|||||||
:type pos: torch.Tensor | LabelTensor
|
:type pos: torch.Tensor | LabelTensor
|
||||||
:param int neighbours: The number of nearest neighbors to consider when
|
:param int neighbours: The number of nearest neighbors to consider when
|
||||||
building the graph.
|
building the graph.
|
||||||
:Keyword Arguments:
|
:param dict kwargs: The additional keyword arguments to be passed to
|
||||||
The additional keyword arguments to be passed to GraphBuilder
|
:class:`GraphBuilder` and :class:`Graph` classes.
|
||||||
and Graph classes
|
|
||||||
|
|
||||||
:return: A :class:`~pina.graph.Graph` instance containg the
|
:return: A :class:`~pina.graph.Graph` instance with the computed
|
||||||
information passed in input and the computed ``edge_index``
|
``edge_index``.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -366,41 +363,45 @@ class KNNGraph(GraphBuilder):
|
|||||||
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
|
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_knn_graph(points, k):
|
def compute_knn_graph(points, neighbours):
|
||||||
"""
|
"""
|
||||||
Computes the edge_index based k-nearest neighbors graph algorithm
|
Computes the ``edge_index`` based on the K-nearest neighbors algorithm.
|
||||||
|
|
||||||
:param points: A tensor of shape ``(N, D)`` representing the positions
|
:param points: A tensor of shape ``(N, D)`` representing the positions
|
||||||
of ``N`` points in ``D``-dimensional space.
|
of ``N`` points in ``D``-dimensional space.
|
||||||
:type points: torch.Tensor | LabelTensor
|
:type points: torch.Tensor | LabelTensor
|
||||||
:param int k: The number of nearest neighbors to find for each point.
|
:param int neighbours: The number of nearest neighbors to consider when
|
||||||
:return: A tensor of shape ``(2, E)``, where ``E`` is the number of
|
building the graph.
|
||||||
edges, representing the edge indices of the KNN graph.
|
:return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
|
||||||
|
representing the edge indices of the graph.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dist = torch.cdist(points, points, p=2)
|
dist = torch.cdist(points, points, p=2)
|
||||||
knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:]
|
knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[
|
||||||
row = torch.arange(points.size(0)).repeat_interleave(k)
|
:, 1:
|
||||||
|
]
|
||||||
|
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
|
||||||
col = knn_indices.flatten()
|
col = knn_indices.flatten()
|
||||||
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
|
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
class LabelBatch(Batch):
|
class LabelBatch(Batch):
|
||||||
"""
|
"""
|
||||||
Add extract function to torch_geometric Batch object
|
Extends the :class:`~torch_geometric.data.Batch` class to include
|
||||||
|
:class:`~pina.label_tensor.LabelTensor` objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_data_list(cls, data_list):
|
def from_data_list(cls, data_list):
|
||||||
"""
|
"""
|
||||||
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
|
||||||
objects.
|
or :class:`~pina.graph.Graph` objects.
|
||||||
|
|
||||||
:param data_list: List of :class:`~torch_geometric.data.Data` or
|
:param data_list: List of :class:`~torch_geometric.data.Data` or
|
||||||
:class:`~pina.graph.Graph` objects.
|
:class:`~pina.graph.Graph` objects.
|
||||||
:type data_list: list[Data] | list[Graph]
|
:type data_list: list[Data] | list[Graph]
|
||||||
:return: A Batch object containing the data in the list
|
:return: A :class:`Batch` object containing the input data.
|
||||||
:rtype: Batch
|
:rtype: Batch
|
||||||
"""
|
"""
|
||||||
# Store the labels of Data/Graph objects (all data have the same labels)
|
# Store the labels of Data/Graph objects (all data have the same labels)
|
||||||
|
|||||||
@@ -216,10 +216,10 @@ class LabelTensor(torch.Tensor):
|
|||||||
def extract(self, labels_to_extract):
|
def extract(self, labels_to_extract):
|
||||||
"""
|
"""
|
||||||
Extract the subset of the original tensor by returning all the positions
|
Extract the subset of the original tensor by returning all the positions
|
||||||
corresponding to the passed ``label_to_extract``. If ``label_to_extract``
|
corresponding to the passed ``label_to_extract``. If
|
||||||
is a dictionary, the keys are the dimension names and the values are the
|
``label_to_extract`` is a dictionary, the keys are the dimension names
|
||||||
labels to extract. If a single label or a list of labels is passed, the
|
and the values are the labels to extract. If a single label or a list
|
||||||
last dimension is considered.
|
of labels is passed, the last dimension is considered.
|
||||||
|
|
||||||
:Example:
|
:Example:
|
||||||
>>> from pina import LabelTensor
|
>>> from pina import LabelTensor
|
||||||
|
|||||||
@@ -109,9 +109,7 @@ class FourierIntegralKernel(torch.nn.Module):
|
|||||||
if all(isinstance(i, list) for i in n_modes) and len(layers) != len(
|
if all(isinstance(i, list) for i in n_modes) and len(layers) != len(
|
||||||
n_modes
|
n_modes
|
||||||
):
|
):
|
||||||
raise RuntimeError(
|
raise RuntimeError("Inconsistent number of layers and modes.")
|
||||||
"Inconsistent number of layers and modes."
|
|
||||||
)
|
|
||||||
if all(isinstance(i, int) for i in n_modes):
|
if all(isinstance(i, int) for i in n_modes):
|
||||||
n_modes = [n_modes] * len(layers)
|
n_modes = [n_modes] * len(layers)
|
||||||
else:
|
else:
|
||||||
@@ -329,7 +327,7 @@ class FNO(KernelNeuralOperator):
|
|||||||
``projection_net`` maps the hidden representation to the output
|
``projection_net`` maps the hidden representation to the output
|
||||||
function.
|
function.
|
||||||
|
|
||||||
: param x: The input tensor for performing the computation. Depending
|
: param x: The input tensor for performing the computation. Depending
|
||||||
on the ``dimensions`` in the initialization, it expects a tensor
|
on the ``dimensions`` in the initialization, it expects a tensor
|
||||||
with the following shapes:
|
with the following shapes:
|
||||||
* 1D tensors: ``[batch, X, channels]``
|
* 1D tensors: ``[batch, X, channels]``
|
||||||
|
|||||||
Reference in New Issue
Block a user