Other fixes
This commit is contained in:
@@ -18,7 +18,8 @@ class Collector:
|
|||||||
and the problem and initializing the data collections (dictionary where
|
and the problem and initializing the data collections (dictionary where
|
||||||
data will be stored).
|
data will be stored).
|
||||||
|
|
||||||
:param AbstractProblem problem: The problem to collect data from.
|
:param pina.problem.abstract_problem.AbstractProblem problem: The
|
||||||
|
problem to collect data from.
|
||||||
"""
|
"""
|
||||||
# creating a hook between collector and problem
|
# creating a hook between collector and problem
|
||||||
self.problem = problem
|
self.problem = problem
|
||||||
@@ -71,16 +72,17 @@ class Collector:
|
|||||||
Problem connected to the collector.
|
Problem connected to the collector.
|
||||||
|
|
||||||
:return: The problem from which the data is collected.
|
:return: The problem from which the data is collected.
|
||||||
:rtype: AbstractProblem
|
:rtype: pina.problem.abstract_problem.AbstractProblem
|
||||||
"""
|
"""
|
||||||
return self._problem
|
return self._problem
|
||||||
|
|
||||||
@problem.setter
|
@problem.setter
|
||||||
def problem(self, value):
|
def problem(self, value):
|
||||||
"""
|
"""
|
||||||
Ser the problem connected to the collector.
|
Set the problem connected to the collector.
|
||||||
|
|
||||||
:param AbstractProblem value: The problem to connect to the collector.
|
:param pina.problem.abstract_problem.AbstractProblem value: The problem
|
||||||
|
to connect to the collector.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._problem = value
|
self._problem = value
|
||||||
|
|||||||
@@ -89,15 +89,15 @@ class Condition:
|
|||||||
Create a new condition object based on the keyword arguments passed.
|
Create a new condition object based on the keyword arguments passed.
|
||||||
|
|
||||||
- `input` and `target`:
|
- `input` and `target`:
|
||||||
:class:`pina.condition.input_target_condition.InputTargetCondition`
|
:class:`~pina.condition.input_target_condition.InputTargetCondition`
|
||||||
- `domain` and `equation`:
|
- `domain` and `equation`:
|
||||||
:class:`pina.condition.domain_equation_condition.
|
:class:`~pina.condition.domain_equation_condition.
|
||||||
DomainEquationCondition`
|
DomainEquationCondition`
|
||||||
- `input` and `equation`: :class:`pina.condition.
|
- `input` and `equation`: :class:`~pina.condition.
|
||||||
input_equation_condition.InputEquationCondition`
|
input_equation_condition.InputEquationCondition`
|
||||||
- `input`: :class:`pina.condition.data_condition.DataCondition`
|
- `input`: :class:`~pina.condition.data_condition.DataCondition`
|
||||||
- `input` and `conditional_variables`:
|
- `input` and `conditional_variables`:
|
||||||
:class:`pina.condition.data_condition.DataCondition`
|
:class:`~pina.condition.data_condition.DataCondition`
|
||||||
|
|
||||||
:raises ValueError: No valid condition has been found.
|
:raises ValueError: No valid condition has been found.
|
||||||
:return: A new condition instance belonging to the proper class.
|
:return: A new condition instance belonging to the proper class.
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class DataCondition(ConditionInterface):
|
|||||||
pina.condition.data_condition.GraphDataCondition
|
pina.condition.data_condition.GraphDataCondition
|
||||||
|
|
||||||
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
:raises ValueError: If input is not of type :class:`torch.Tensor`,
|
||||||
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`,
|
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
|
||||||
or :class:`~torch_geometric.data.Data`.
|
or :class:`~torch_geometric.data.Data`.
|
||||||
|
|
||||||
|
|
||||||
@@ -68,7 +68,7 @@ class DataCondition(ConditionInterface):
|
|||||||
:type conditional_variables: torch.Tensor or LabelTensor
|
:type conditional_variables: torch.Tensor or LabelTensor
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If either `input` is composed by a list of :class:`pina.graph.Graph`
|
If either `input` is composed by a list of :class:`~pina.graph.Graph`
|
||||||
or :class:`~torch_geometric.data.Data` objects, all elements must
|
or :class:`~torch_geometric.data.Data` objects, all elements must
|
||||||
have the same structure (keys and data types)
|
have the same structure (keys and data types)
|
||||||
"""
|
"""
|
||||||
@@ -80,12 +80,12 @@ class DataCondition(ConditionInterface):
|
|||||||
class TensorDataCondition(DataCondition):
|
class TensorDataCondition(DataCondition):
|
||||||
"""
|
"""
|
||||||
DataCondition for :class:`torch.Tensor` or
|
DataCondition for :class:`torch.Tensor` or
|
||||||
:class:`pina.label_tensor.LabelTensor` input data
|
:class:`~pina.label_tensor.LabelTensor` input data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GraphDataCondition(DataCondition):
|
class GraphDataCondition(DataCondition):
|
||||||
"""
|
"""
|
||||||
DataCondition for :class:`pina.graph.Graph` or
|
DataCondition for :class:`~pina.graph.Graph` or
|
||||||
:class:`~torch_geometric.data.Data` input data
|
:class:`~torch_geometric.data.Data` input data
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class InputEquationCondition(ConditionInterface):
|
|||||||
pina.condition.input_equation_condition.InputGraphEquationCondition
|
pina.condition.input_equation_condition.InputGraphEquationCondition
|
||||||
|
|
||||||
:raises ValueError: If input is not of type
|
:raises ValueError: If input is not of type
|
||||||
:class:`pina.label_tensor.LabelTensor`, :class:`pina.graph.Graph`.
|
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If the class is already a subclass, return the instance
|
# If the class is already a subclass, return the instance
|
||||||
@@ -62,15 +62,16 @@ class InputEquationCondition(ConditionInterface):
|
|||||||
Initialize the InputEquationCondition by storing the input and equation.
|
Initialize the InputEquationCondition by storing the input and equation.
|
||||||
|
|
||||||
:param input: Input data for the condition.
|
:param input: Input data for the condition.
|
||||||
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
|
:type input: pina.label_tensor.LabelTensor | pina.graph.Graph |
|
||||||
|
list[pina.graph.Graph] | tuple[pina.graph.Graph]
|
||||||
:param EquationInterface equation: Equation object containing the
|
:param EquationInterface equation: Equation object containing the
|
||||||
equation function.
|
equation function.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If `input` is composed by a list of :class:`pina.graph.Graph`
|
If `input` is composed by a list of :class:`~pina.graph.Graph`
|
||||||
objects, all elements must have the same structure (keys and data
|
objects, all elements must have the same structure (keys and data
|
||||||
types). Moreover, at least one attribute must be a
|
types). Moreover, at least one attribute must be a
|
||||||
:class:`pina.label_tensor.LabelTensor`.
|
:class:`~pina.label_tensor.LabelTensor`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -90,21 +91,21 @@ class InputEquationCondition(ConditionInterface):
|
|||||||
|
|
||||||
class InputTensorEquationCondition(InputEquationCondition):
|
class InputTensorEquationCondition(InputEquationCondition):
|
||||||
"""
|
"""
|
||||||
InputEquationCondition subclass for :class:`pina.label_tensor.LabelTensor`
|
InputEquationCondition subclass for :class:`~pina.label_tensor.LabelTensor`
|
||||||
input data.
|
input data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class InputGraphEquationCondition(InputEquationCondition):
|
class InputGraphEquationCondition(InputEquationCondition):
|
||||||
"""
|
"""
|
||||||
InputEquationCondition subclass for :class:`pina.graph.Graph` input data.
|
InputEquationCondition subclass for :class:`~pina.graph.Graph` input data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_label_tensor(input):
|
def _check_label_tensor(input):
|
||||||
"""
|
"""
|
||||||
Check if at least one :class:`pina.label_tensor.LabelTensor` is present
|
Check if at least one :class:`~pina.label_tensor.LabelTensor` is present
|
||||||
in the :class:`pina.graph.Graph` object.
|
in the :class:`~pina.graph.Graph` object.
|
||||||
|
|
||||||
:param input: Input data.
|
:param input: Input data.
|
||||||
:type input: torch.Tensor | Graph | Data
|
:type input: torch.Tensor | Graph | Data
|
||||||
|
|||||||
@@ -44,8 +44,8 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
pina.condition.input_target_condition.GraphInputGraphTargetCondition
|
pina.condition.input_target_condition.GraphInputGraphTargetCondition
|
||||||
|
|
||||||
:raises ValueError: If input and or target are not of type
|
:raises ValueError: If input and or target are not of type
|
||||||
:class:`torch.Tensor`, :class:`pina.label_tensor.LabelTensor`,
|
:class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
|
||||||
:class:`pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
|
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
|
||||||
"""
|
"""
|
||||||
if cls != InputTargetCondition:
|
if cls != InputTargetCondition:
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
@@ -97,7 +97,7 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If either `input` or `target` are composed by a list of
|
If either `input` or `target` are composed by a list of
|
||||||
:class:`pina.graph.Graph` or :class:`~torch_geometric.data.Data`
|
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
|
||||||
objects, all elements must have the same structure (keys and data
|
objects, all elements must have the same structure (keys and data
|
||||||
types)
|
types)
|
||||||
"""
|
"""
|
||||||
@@ -122,29 +122,29 @@ class InputTargetCondition(ConditionInterface):
|
|||||||
class TensorInputTensorTargetCondition(InputTargetCondition):
|
class TensorInputTensorTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`torch.Tensor` or
|
InputTargetCondition subclass for :class:`torch.Tensor` or
|
||||||
:class:`pina.label_tensor.LabelTensor` input and target data.
|
:class:`~pina.label_tensor.LabelTensor` input and target data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class TensorInputGraphTargetCondition(InputTargetCondition):
|
class TensorInputGraphTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`torch.Tensor` or
|
InputTargetCondition subclass for :class:`torch.Tensor` or
|
||||||
:class:`pina.label_tensor.LabelTensor` input and
|
:class:`~pina.label_tensor.LabelTensor` input and
|
||||||
:class:`pina.graph.Graph` or :class:`~torch_geometric.data.Data` target
|
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` target
|
||||||
data.
|
data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GraphInputTensorTargetCondition(InputTargetCondition):
|
class GraphInputTensorTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`pina.graph.Graph` o
|
InputTargetCondition subclass for :class:`~pina.graph.Graph` o
|
||||||
:class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or
|
:class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or
|
||||||
:class:`pina.label_tensor.LabelTensor` target data.
|
:class:`~pina.label_tensor.LabelTensor` target data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GraphInputGraphTargetCondition(InputTargetCondition):
|
class GraphInputGraphTargetCondition(InputTargetCondition):
|
||||||
"""
|
"""
|
||||||
InputTargetCondition subclass for :class:`pina.graph.Graph`/
|
InputTargetCondition subclass for :class:`~pina.graph.Graph`/
|
||||||
:class:`~torch_geometric.data.Data` input and target data.
|
:class:`~torch_geometric.data.Data` input and target data.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ class Collator:
|
|||||||
def _collate_tensor_dataset(data_list):
|
def _collate_tensor_dataset(data_list):
|
||||||
"""
|
"""
|
||||||
Function used to collate the data when the dataset is a
|
Function used to collate the data when the dataset is a
|
||||||
:class:`pina.data.dataset.PinaTensorDataset`.
|
:class:`~pina.data.dataset.PinaTensorDataset`.
|
||||||
|
|
||||||
:param data_list: Elements to be collated.
|
:param data_list: Elements to be collated.
|
||||||
:type data_list: list[torch.Tensor] | list[LabelTensor]
|
:type data_list: list[torch.Tensor] | list[LabelTensor]
|
||||||
@@ -165,7 +165,7 @@ class Collator:
|
|||||||
:rtype: dict
|
:rtype: dict
|
||||||
|
|
||||||
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
|
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
|
||||||
:class:`pina.label_tensor.LabelTensor`.
|
:class:`~pina.label_tensor.LabelTensor`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data_list[0], LabelTensor):
|
if isinstance(data_list[0], LabelTensor):
|
||||||
@@ -177,7 +177,7 @@ class Collator:
|
|||||||
def _collate_graph_dataset(self, data_list):
|
def _collate_graph_dataset(self, data_list):
|
||||||
"""
|
"""
|
||||||
Function used to collate the data when the dataset is a
|
Function used to collate the data when the dataset is a
|
||||||
:class:`pina.data.dataset.PinaGraphDataset`.
|
:class:`~pina.data.dataset.PinaGraphDataset`.
|
||||||
|
|
||||||
:param data_list: Elememts to be collated.
|
:param data_list: Elememts to be collated.
|
||||||
:type data_list: list[Data] | list[Graph]
|
:type data_list: list[Data] | list[Graph]
|
||||||
@@ -185,7 +185,7 @@ class Collator:
|
|||||||
:rtype: dict
|
:rtype: dict
|
||||||
|
|
||||||
:raises RuntimeError: If the data is not a
|
:raises RuntimeError: If the data is not a
|
||||||
:class:`~torch_geometric.data.Data` or a :class:`pina.graph.Graph`.
|
:class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(data_list[0], LabelTensor):
|
if isinstance(data_list[0], LabelTensor):
|
||||||
|
|||||||
@@ -15,26 +15,26 @@ class PinaDatasetFactory:
|
|||||||
Depending on the type inside the conditions, it creates a different dataset
|
Depending on the type inside the conditions, it creates a different dataset
|
||||||
object:
|
object:
|
||||||
|
|
||||||
- :class:`pina.data.dataset.PinaTensorDataset` for handling
|
- :class:`~pina.data.dataset.PinaTensorDataset` for handling
|
||||||
:class:`torch.Tensor` and :class:`pina.label_tensor.LabelTensor` data.
|
:class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
|
||||||
- :class:`pina.data.dataset.PinaGraphDataset` for handling
|
- :class:`~pina.data.dataset.PinaGraphDataset` for handling
|
||||||
:class:`pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
:class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, conditions_dict, **kwargs):
|
def __new__(cls, conditions_dict, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate the appropriate subclass of
|
Instantiate the appropriate subclass of
|
||||||
:class:`pina.data.dataset.PinaDataset`.
|
:class:`~pina.data.dataset.PinaDataset`.
|
||||||
|
|
||||||
If a graph is present in the conditions, returns a
|
If a graph is present in the conditions, returns a
|
||||||
:class:`pina.data.dataset.PinaGraphDataset`, otherwise returns a
|
:class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a
|
||||||
:class:`pina.data.dataset.PinaTensorDataset`.
|
:class:`~pina.data.dataset.PinaTensorDataset`.
|
||||||
|
|
||||||
:param dict conditions_dict: Dictionary containing all the conditions
|
:param dict conditions_dict: Dictionary containing all the conditions
|
||||||
to be included in the dataset instance.
|
to be included in the dataset instance.
|
||||||
:return: A subclass of :class:`pina.data.dataset.PinaDataset`.
|
:return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
|
||||||
:rtype: :class:`pina.data.dataset.PinaTensorDataset` |
|
:rtype: :class:`~pina.data.dataset.PinaTensorDataset` |
|
||||||
:class:`pina.data.dataset.PinaGraphDataset`
|
:class:`~pina.data.dataset.PinaGraphDataset`
|
||||||
|
|
||||||
:raises ValueError: If an empty dictionary is provided.
|
:raises ValueError: If an empty dictionary is provided.
|
||||||
"""
|
"""
|
||||||
@@ -75,15 +75,15 @@ class PinaDatasetFactory:
|
|||||||
class PinaDataset(Dataset, ABC):
|
class PinaDataset(Dataset, ABC):
|
||||||
"""
|
"""
|
||||||
Abstract class for the PINA dataset. It defines the common interface for
|
Abstract class for the PINA dataset. It defines the common interface for
|
||||||
the :class:`pina.data.dataset.PinaTensorDataset` and
|
the :class:`~pina.data.dataset.PinaTensorDataset` and
|
||||||
:class:`pina.data.dataset.PinaGraphDataset` classes.
|
:class:`~pina.data.dataset.PinaGraphDataset` classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, conditions_dict, max_conditions_lengths, automatic_batching
|
self, conditions_dict, max_conditions_lengths, automatic_batching
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a :class:`pina.data.dataset.PinaDataset` instance by storing
|
Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing
|
||||||
the providedconditions dictionary, the maximum number of conditions to
|
the providedconditions dictionary, the maximum number of conditions to
|
||||||
consider, and the automatic batching flag.
|
consider, and the automatic batching flag.
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ class PinaDataset(Dataset, ABC):
|
|||||||
points to include in a single batch for each condition.
|
points to include in a single batch for each condition.
|
||||||
:param bool automatic_batching: Indicates whether PyTorch automatic
|
:param bool automatic_batching: Indicates whether PyTorch automatic
|
||||||
batching is enabled in
|
batching is enabled in
|
||||||
:class:`pina.data.data_module.PinaDataModule`.
|
:class:`~pina.data.data_module.PinaDataModule`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Store the conditions dictionary
|
# Store the conditions dictionary
|
||||||
@@ -205,7 +205,7 @@ class PinaDataset(Dataset, ABC):
|
|||||||
class PinaTensorDataset(PinaDataset):
|
class PinaTensorDataset(PinaDataset):
|
||||||
"""
|
"""
|
||||||
Dataset class for the PINA dataset with :class:`torch.Tensor` and
|
Dataset class for the PINA dataset with :class:`torch.Tensor` and
|
||||||
:class:`pina.label_tensor.LabelTensor` data.
|
:class:`~pina.label_tensor.LabelTensor` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Override _retrive_data method for torch.Tensor data
|
# Override _retrive_data method for torch.Tensor data
|
||||||
@@ -215,7 +215,7 @@ class PinaTensorDataset(PinaDataset):
|
|||||||
|
|
||||||
:param dict data: Dictionary containing the data
|
:param dict data: Dictionary containing the data
|
||||||
(only :class:`torch.Tensor` or
|
(only :class:`torch.Tensor` or
|
||||||
:class:`pina.label_tensor.LabelTensor`).
|
:class:`~pina.label_tensor.LabelTensor`).
|
||||||
:param list[int] idx_list: indices to retrieve.
|
:param list[int] idx_list: indices to retrieve.
|
||||||
:return: Dictionary containing the data at the given indices.
|
:return: Dictionary containing the data at the given indices.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
@@ -237,7 +237,7 @@ class PinaTensorDataset(PinaDataset):
|
|||||||
class PinaGraphDataset(PinaDataset):
|
class PinaGraphDataset(PinaDataset):
|
||||||
"""
|
"""
|
||||||
Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
|
Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
|
||||||
and :class:`pina.graph.Graph` data.
|
and :class:`~pina.graph.Graph` data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _create_graph_batch(self, data):
|
def _create_graph_batch(self, data):
|
||||||
|
|||||||
@@ -20,13 +20,13 @@ class Graph(Data):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new instance of the :class:`pina.graph.Graph` class by checking
|
Create a new instance of the :class:`~pina.graph.Graph` class by checking
|
||||||
the consistency of the input data and storing the attributes.
|
the consistency of the input data and storing the attributes.
|
||||||
|
|
||||||
:param kwargs: Parameters used to initialize the
|
:param kwargs: Parameters used to initialize the
|
||||||
:class:`pina.graph.Graph` object.
|
:class:`~pina.graph.Graph` object.
|
||||||
:type kwargs: dict
|
:type kwargs: dict
|
||||||
:return: A new instance of the :class:`pina.graph.Graph` class.
|
:return: A new instance of the :class:`~pina.graph.Graph` class.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -230,7 +230,7 @@ class GraphBuilder:
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compute the edge attributes and create a new instance of the
|
Compute the edge attributes and create a new instance of the
|
||||||
:class:`pina.graph.Graph` class.
|
:class:`~pina.graph.Graph` class.
|
||||||
|
|
||||||
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
:param pos: A tensor of shape `(N, D)` representing the positions of `N`
|
||||||
points in `D`-dimensional space.
|
points in `D`-dimensional space.
|
||||||
@@ -248,8 +248,8 @@ class GraphBuilder:
|
|||||||
If provided, overrides `edge_attr`.
|
If provided, overrides `edge_attr`.
|
||||||
:type custom_edge_func: callable, optional
|
:type custom_edge_func: callable, optional
|
||||||
:param kwargs: Additional keyword arguments passed to the
|
:param kwargs: Additional keyword arguments passed to the
|
||||||
:class:`pina.graph.Graph` class constructor.
|
:class:`~pina.graph.Graph` class constructor.
|
||||||
:return: A :class:`pina.graph.Graph` instance constructed using the
|
:return: A :class:`~pina.graph.Graph` instance constructed using the
|
||||||
provided information.
|
provided information.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
@@ -289,7 +289,7 @@ class RadiusGraph(GraphBuilder):
|
|||||||
|
|
||||||
def __new__(cls, pos, radius, **kwargs):
|
def __new__(cls, pos, radius, **kwargs):
|
||||||
"""
|
"""
|
||||||
Extends the :class:`pina.graph.GraphBuilder` class to compute
|
Extends the :class:`~pina.graph.GraphBuilder` class to compute
|
||||||
edge_index based on a radius. Each point is connected to all the points
|
edge_index based on a radius. Each point is connected to all the points
|
||||||
within the radius.
|
within the radius.
|
||||||
|
|
||||||
@@ -299,9 +299,9 @@ class RadiusGraph(GraphBuilder):
|
|||||||
:param radius: The radius within which points are connected.
|
:param radius: The radius within which points are connected.
|
||||||
:type radius: float
|
:type radius: float
|
||||||
:param kwargs: Additional keyword arguments to be passed to the
|
:param kwargs: Additional keyword arguments to be passed to the
|
||||||
:class:`pina.graph.GraphBuilder` and :class:`pina.graph.Graph`
|
:class:`~pina.graph.GraphBuilder` and :class:`~pina.graph.Graph`
|
||||||
constructors.
|
constructors.
|
||||||
:return: A :class:`pina.graph.Graph` instance containing the input
|
:return: A :class:`~pina.graph.Graph` instance containing the input
|
||||||
information and the computed edge_index.
|
information and the computed edge_index.
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
@@ -339,7 +339,7 @@ class KNNGraph(GraphBuilder):
|
|||||||
|
|
||||||
def __new__(cls, pos, neighbours, **kwargs):
|
def __new__(cls, pos, neighbours, **kwargs):
|
||||||
"""
|
"""
|
||||||
Extends the :class:`pina.graph.GraphBuilder` class to compute edge_index
|
Extends the :class:`~pina.graph.GraphBuilder` class to compute edge_index
|
||||||
based on a K-nearest neighbors algorithm.
|
based on a K-nearest neighbors algorithm.
|
||||||
|
|
||||||
:param pos: A tensor of shape (N, D) representing the positions of N
|
:param pos: A tensor of shape (N, D) representing the positions of N
|
||||||
@@ -351,7 +351,7 @@ class KNNGraph(GraphBuilder):
|
|||||||
The additional keyword arguments to be passed to GraphBuilder
|
The additional keyword arguments to be passed to GraphBuilder
|
||||||
and Graph classes
|
and Graph classes
|
||||||
|
|
||||||
:return: A :class:`pina.graph.Graph` instance containg the
|
:return: A :class:`~pina.graph.Graph` instance containg the
|
||||||
information passed in input and the computed edge_index
|
information passed in input and the computed edge_index
|
||||||
:rtype: Graph
|
:rtype: Graph
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,14 +14,14 @@ class LabelTensor(torch.Tensor):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def __new__(cls, x, labels, *args, **kwargs):
|
def __new__(cls, x, labels, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create a new instance of the :class:`pina.label_tensor.LabelTensor`
|
Create a new instance of the :class:`~pina.label_tensor.LabelTensor`
|
||||||
class.
|
class.
|
||||||
|
|
||||||
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
|
:param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
|
||||||
:class:`pina.label_tensor.LabelTensor`.
|
:class:`~pina.label_tensor.LabelTensor`.
|
||||||
:param labels: Labels to assign to the tensor.
|
:param labels: Labels to assign to the tensor.
|
||||||
:type labels: str | list[str] | dict
|
:type labels: str | list[str] | dict
|
||||||
:return: The instance of the :class:`pina.label_tensor.LabelTensor`
|
:return: The instance of the :class:`~pina.label_tensor.LabelTensor`
|
||||||
class.
|
class.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
@@ -33,10 +33,10 @@ class LabelTensor(torch.Tensor):
|
|||||||
@property
|
@property
|
||||||
def tensor(self):
|
def tensor(self):
|
||||||
"""
|
"""
|
||||||
Give the tensor part of the :class:`pina.label_tensor.LabelTensor`
|
Give the tensor part of the :class:`~pina.label_tensor.LabelTensor`
|
||||||
object.
|
object.
|
||||||
|
|
||||||
:return: tensor part of the :class:`pina.label_tensor.LabelTensor`.
|
:return: tensor part of the :class:`~pina.label_tensor.LabelTensor`.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
def __init__(self, x, labels):
|
def __init__(self, x, labels):
|
||||||
"""
|
"""
|
||||||
Construct a :class:`pina.label_tensor.LabelTensor` by passing a dict of
|
Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of
|
||||||
the labels and a :class:`torch.Tensor`. Internally, the initialization
|
the labels and a :class:`torch.Tensor`. Internally, the initialization
|
||||||
method will store check the compatibility of the labels with the tensor
|
method will store check the compatibility of the labels with the tensor
|
||||||
shape.
|
shape.
|
||||||
@@ -275,10 +275,10 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
"""
|
"""
|
||||||
The string representation of the :class:`pina.label_tensor.LabelTensor`.
|
The string representation of the :class:`~pina.label_tensor.LabelTensor`.
|
||||||
|
|
||||||
:return: String representation of the
|
:return: String representation of the
|
||||||
:class:`pina.label_tensor.LabelTensor` instance.
|
:class:`~pina.label_tensor.LabelTensor` instance.
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -295,7 +295,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
Concatenate a list of tensors along a specified dimension. For more
|
Concatenate a list of tensors along a specified dimension. For more
|
||||||
details, see :meth:`torch.cat`.
|
details, see :meth:`torch.cat`.
|
||||||
|
|
||||||
:param list[LabelTensor] tensors: :class:`pina.label_tensor.LabelTensor`
|
:param list[LabelTensor] tensors: :class:`~pina.label_tensor.LabelTensor`
|
||||||
instances to concatenate
|
instances to concatenate
|
||||||
:param int dim: dimensions on which you want to perform the operation
|
:param int dim: dimensions on which you want to perform the operation
|
||||||
(default is 0)
|
(default is 0)
|
||||||
@@ -351,7 +351,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
:param list[LabelTensor] tensors: A list of tensors to stack.
|
:param list[LabelTensor] tensors: A list of tensors to stack.
|
||||||
All tensors must have the same shape.
|
All tensors must have the same shape.
|
||||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
|
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
|
||||||
by stacking the input tensors, with the updated labels.
|
by stacking the input tensors, with the updated labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
@@ -373,7 +373,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
:param bool mode: A boolean value indicating whether the tensor should
|
:param bool mode: A boolean value indicating whether the tensor should
|
||||||
track gradients.If `True`, the tensor will track gradients;
|
track gradients.If `True`, the tensor will track gradients;
|
||||||
if `False`, it will not.
|
if `False`, it will not.
|
||||||
:return: The :class:`pina.label_tensor.LabelTensor` itself with the
|
:return: The :class:`~pina.label_tensor.LabelTensor` itself with the
|
||||||
updated `requires_grad` state and retained labels.
|
updated `requires_grad` state and retained labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
@@ -399,7 +399,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
Performs Tensor dtype and/or device conversion. For more details, see
|
Performs Tensor dtype and/or device conversion. For more details, see
|
||||||
:meth:`torch.Tensor.to`.
|
:meth:`torch.Tensor.to`.
|
||||||
|
|
||||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
|
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
|
||||||
updated dtype and/or device and retained labels.
|
updated dtype and/or device and retained labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
@@ -410,10 +410,10 @@ class LabelTensor(torch.Tensor):
|
|||||||
|
|
||||||
def clone(self, *args, **kwargs):
|
def clone(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Clone the :class:`pina.label_tensor.LabelTensor`. For more details, see
|
Clone the :class:`~pina.label_tensor.LabelTensor`. For more details, see
|
||||||
:meth:`torch.Tensor.clone`.
|
:meth:`torch.Tensor.clone`.
|
||||||
|
|
||||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
|
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
|
||||||
same data and labels but allocated in a different memory location.
|
same data and labels but allocated in a different memory location.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
@@ -472,9 +472,9 @@ class LabelTensor(torch.Tensor):
|
|||||||
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
|
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
|
||||||
|
|
||||||
:param list of LabelTensor label_tensors: The
|
:param list of LabelTensor label_tensors: The
|
||||||
:class:`pina.label_tensor.LabelTensor` instances to stack. They need
|
:class:`~pina.label_tensor.LabelTensor` instances to stack. They need
|
||||||
to have equal labels.
|
to have equal labels.
|
||||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
|
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
|
||||||
by stacking the input tensors vertically.
|
by stacking the input tensors vertically.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
@@ -530,13 +530,13 @@ class LabelTensor(torch.Tensor):
|
|||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
""" "
|
""" "
|
||||||
Override the __getitem__ method to handle the labels of the
|
Override the __getitem__ method to handle the labels of the
|
||||||
:class:`pina.label_tensor.LabelTensor` instance. It first performs
|
:class:`~pina.label_tensor.LabelTensor` instance. It first performs
|
||||||
__getitem__ operation on the :class:`torch.Tensor` part of the instance,
|
__getitem__ operation on the :class:`torch.Tensor` part of the instance,
|
||||||
then updates the labels based on the index.
|
then updates the labels based on the index.
|
||||||
|
|
||||||
:param index: The index used to access the item
|
:param index: The index used to access the item
|
||||||
:type index: int | str | tuple of int | list ot int | torch.Tensor
|
:type index: int | str | tuple of int | list ot int | torch.Tensor
|
||||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance obtained
|
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
|
||||||
`__getitem__` operation on :class:`torch.Tensor` part of the
|
`__getitem__` operation on :class:`torch.Tensor` part of the
|
||||||
instance, with the updated labels.
|
instance, with the updated labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
@@ -672,7 +672,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
def summation(tensors):
|
def summation(tensors):
|
||||||
"""
|
"""
|
||||||
Computes the summation of a list of
|
Computes the summation of a list of
|
||||||
:class:`pina.label_tensor.LabelTensor` instances.
|
:class:`~pina.label_tensor.LabelTensor` instances.
|
||||||
|
|
||||||
|
|
||||||
:param list[LabelTensor] tensors: A list of tensors to sum. All
|
:param list[LabelTensor] tensors: A list of tensors to sum. All
|
||||||
@@ -719,7 +719,7 @@ class LabelTensor(torch.Tensor):
|
|||||||
For more details, see :meth:`torch.Tensor.reshape`.
|
For more details, see :meth:`torch.Tensor.reshape`.
|
||||||
|
|
||||||
:param tuple of int shape: The new shape of the tensor.
|
:param tuple of int shape: The new shape of the tensor.
|
||||||
:return: A new :class:`pina.label_tensor.LabelTensor` instance with the
|
:return: A new :class:`~pina.label_tensor.LabelTensor` instance with the
|
||||||
updated shape and labels.
|
updated shape and labels.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user