diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index d31030d..eaee604 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -45,7 +45,7 @@ class ConditionInterface(metaclass=ABCMeta): objects is consistent. :param data_list: List of graph type objects. - :type data_list: torch_geometric.data.Data | Graph| + :type data_list: torch_geometric.data.Data | Graph| list[torch_geometric.data.Data] | list[Graph] :raises ValueError: Input data must be either torch_geometric.data.Data diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 5f158af..84fb540 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -26,8 +26,8 @@ class DataCondition(ConditionInterface): types of input data. :param input: Input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | - torch_geometric.data.Data | list[Graph] | + :type input: torch.Tensor | LabelTensor | Graph | + torch_geometric.data.Data | list[Graph] | list[torch_geometric.data.Data] | tuple[Graph] | tuple[torch_geometric.data.Data] :param conditional_variables: Conditional variables for the condition. @@ -63,8 +63,8 @@ class DataCondition(ConditionInterface): variables (if any). :param input: Input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | - torch_geometric.data.Data | list[Graph] | + :type input: torch.Tensor | LabelTensor | Graph | + torch_geometric.data.Data | list[Graph] | list[torch_geometric.data.Data] | tuple[Graph] | tuple[torch_geometric.data.Data] :param conditional_variables: Conditional variables for the condition. diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 7239051..78446e4 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -85,13 +85,13 @@ class InputTargetCondition(ConditionInterface): Initialize the InputTargetCondition, storing the input and target data. :param input: Input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | - torch_geometric.data.Data | list[Graph] | + :type input: torch.Tensor | LabelTensor | Graph | + torch_geometric.data.Data | list[Graph] | list[torch_geometric.data.Data] | tuple[Graph] | tuple[torch_geometric.data.Data] :param target: Target data for the condition. - :type target: torch.Tensor | LabelTensor | Graph | - torch_geometric.data.Data | list[Graph] | + :type target: torch.Tensor | LabelTensor | Graph | + torch_geometric.data.Data | list[Graph] | list[torch_geometric.data.Data] | tuple[Graph] | tuple[torch_geometric.data.Data] diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 6a91a6c..9109083 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -668,7 +668,7 @@ class LabelTensor(torch.Tensor): Computes the summation of a list of :class:`LabelTensor` instances. - :param list[LabelTensor] tensors: A list of tensors to sum. All + :param list[LabelTensor] tensors: A list of tensors to sum. All tensors must have the same shape and labels. :return: A new `LabelTensor` containing the element-wise sum of the input tensors.