From ab6ca78d8579b3a724390b9fd2762963ad94a40c Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Mon, 3 Mar 2025 09:30:44 +0100 Subject: [PATCH] Simplify Graph class (#459) * Simplifying Graph class and adjust tests --------- Co-authored-by: Dario Coscia --- pina/graph.py | 548 +++++++++--------- tests/test_collector.py | 120 ++-- tests/test_data/test_data_module.py | 186 +++--- tests/test_data/test_graph_dataset.py | 137 +++-- tests/test_graph.py | 465 ++++++++++----- .../test_model/test_graph_neural_operator.py | 141 ++--- .../test_supervised_problem.py | 31 +- 7 files changed, 909 insertions(+), 719 deletions(-) diff --git a/pina/graph.py b/pina/graph.py index 3bfb370..77e426e 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -1,319 +1,319 @@ -from logging import warning +""" +This module provides an interface to build torch_geometric.data.Data objects. +""" import torch - -from . import LabelTensor from torch_geometric.data import Data from torch_geometric.utils import to_undirected -import inspect +from . import LabelTensor +from .utils import check_consistency, is_function -class Graph: +class Graph(Data): """ - Class for the graph construction. + A class to build torch_geometric.data.Data objects. """ + def __new__( + cls, + **kwargs, + ): + """ + :param kwargs: Parameters to construct the Graph object. + :return: A new instance of the Graph class. + :rtype: Graph + """ + # create class instance + instance = Data.__new__(cls) + + # check the consistency of types defined in __init__, the others are not + # checked (as in pyg Data object) + instance._check_type_consistency(**kwargs) + + return instance + def __init__( self, - x, - pos, - edge_index, + x=None, + edge_index=None, + pos=None, edge_attr=None, - build_edge_attr=False, undirected=False, - custom_build_edge_attr=None, - additional_params=None, + **kwargs, ): """ - Constructor for the Graph class. This object creates a list of PyTorch Geometric Data objects. - Based on the input of x and pos there could be the following cases: - 1. 1 pos, 1 x: a single graph will be created - 2. N pos, 1 x: N graphs will be created with the same node features - 3. 1 pos, N x: N graphs will be created with the same nodes but different node features - 4. N pos, N x: N graphs will be created + Initialize the Graph object. - :param x: Node features. Can be a single 2D tensor of shape [num_nodes, num_node_features], - or a 3D tensor of shape [n_graphs, num_nodes, num_node_features] - or a list of such 2D tensors of shape [num_nodes, num_node_features]. - :type x: torch.Tensor or list[torch.Tensor] - :param pos: Node coordinates. Can be a single 2D tensor of shape [num_nodes, num_coordinates], - or a 3D tensor of shape [n_graphs, num_nodes, num_coordinates] - or a list of such 2D tensors of shape [num_nodes, num_coordinates]. - :type pos: torch.Tensor or list[torch.Tensor] - :param edge_index: The edge index defining connections between nodes. - It should be a 2D tensor of shape [2, num_edges] - or a 3D tensor of shape [n_graphs, 2, num_edges] - or a list of such 2D tensors of shape [2, num_edges]. - :type edge_index: torch.Tensor or list[torch.Tensor] - :param edge_attr: Edge features. If provided, should have the shape [num_edges, num_edge_features] - or be a list of such tensors for multiple graphs. - :type edge_attr: torch.Tensor or list[torch.Tensor], optional - :param build_edge_attr: Whether to compute edge attributes during initialization. - :type build_edge_attr: bool, default=False - :param undirected: If True, converts the graph(s) into an undirected graph by adding reciprocal edges. - :type undirected: bool, default=False - :param custom_build_edge_attr: A user-defined function to generate edge attributes dynamically. - The function should take (x, pos, edge_index) as input and return a tensor - of shape [num_edges, num_edge_features]. - :type custom_build_edge_attr: function or callable, optional - :param additional_params: Dictionary containing extra attributes to be added to each Data object. - Keys represent attribute names, and values should be tensors or lists of tensors. - :type additional_params: dict, optional - - Note: if x, pos, and edge_index are both lists or 3D tensors, then len(x) == len(pos) == len(edge_index). + :param x: Optional tensor of node features (N, F) where F is the number + of features per node. + :type x: torch.Tensor, LabelTensor + :param torch.Tensor edge_index: A tensor of shape (2, E) representing + the indices of the graph's edges. + :param pos: A tensor of shape (N, D) representing the positions of N + points in D-dimensional space. + :type pos: torch.Tensor | LabelTensor + :param edge_attr: Optional tensor of edge_featured (E, F') where F' is + the number of edge features + :param bool undirected: Whether to make the graph undirected + :param kwargs: Additional keyword arguments passed to the + `torch_geometric.data.Data` class constructor. If the argument + is a `torch.Tensor` or `LabelTensor`, it is included in the Data + object as a graph parameter. """ + # preprocessing + self._preprocess_edge_index(edge_index, undirected) - self.data = [] - x, pos, edge_index = self._check_input_consistency(x, pos, edge_index) - - # Check input dimension consistency and store the number of graphs - data_len = self._check_len_consistency(x, pos) - if inspect.isfunction(custom_build_edge_attr): - self._build_edge_attr = custom_build_edge_attr - - # Check consistency and initialize additional_parameters (if present) - additional_params = self._check_additional_params( - additional_params, data_len + # calling init + super().__init__( + x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, **kwargs ) - # Make the graphs undirected - if undirected: - if isinstance(edge_index, list): - edge_index = [to_undirected(e) for e in edge_index] - else: - edge_index = to_undirected(edge_index) - - # Prepare internal lists to create a graph list (same positions but - # different node features) - if isinstance(x, list) and isinstance(pos, (torch.Tensor, LabelTensor)): - # Replicate the positions, edge_index and edge_attr - pos, edge_index = [pos] * data_len, [edge_index] * data_len - # Prepare internal lists to create a list containing a single graph - elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance( - pos, (torch.Tensor, LabelTensor) - ): - # Encapsulate the input tensors into lists - x, pos, edge_index = [x], [pos], [edge_index] - # Prepare internal lists to create a list of graphs (same node features - # but different positions) - elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance( - pos, list - ): - # Replicate the node features - x = [x] * data_len - elif not isinstance(x, list) and not isinstance(pos, list): - raise TypeError("x and pos must be lists or tensors.") - - # Build the edge attributes - edge_attr = self._check_and_build_edge_attr( - edge_attr, build_edge_attr, data_len, edge_index, pos, x - ) - - # Perform the graph construction - self._build_graph_list(x, pos, edge_index, edge_attr, additional_params) - - def _build_graph_list( - self, x, pos, edge_index, edge_attr, additional_params - ): - for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)): - add_params_local = {k: v[i] for k, v in additional_params.items()} - if edge_attr is not None: - self.data.append( - Data( - x=x_, - pos=pos_, - edge_index=edge_index_, - edge_attr=edge_attr[i], - **add_params_local, - ) - ) - else: - self.data.append( - Data( - x=x_, - pos=pos_, - edge_index=edge_index_, - **add_params_local, - ) - ) + def _check_type_consistency(self, **kwargs): + # default types, specified in cls.__new__, by default they are Nont + # if specified in **kwargs they get override + x, pos, edge_index, edge_attr = None, None, None, None + if "pos" in kwargs: + pos = kwargs["pos"] + self._check_pos_consistency(pos) + if "edge_index" in kwargs: + edge_index = kwargs["edge_index"] + self._check_edge_index_consistency(edge_index) + if "x" in kwargs: + x = kwargs["x"] + self._check_x_consistency(x, pos) + if "edge_attr" in kwargs: + edge_attr = kwargs["edge_attr"] + self._check_edge_attr_consistency(edge_attr, edge_index) + if "undirected" in kwargs: + undirected = kwargs["undirected"] + check_consistency(undirected, bool) @staticmethod - def _build_edge_attr(x, pos, edge_index): - distance = torch.abs( - pos[edge_index[0]] - pos[edge_index[1]] - ).as_subclass(torch.Tensor) - return distance + def _check_pos_consistency(pos): + """ + Check if the position tensor is consistent. + :param torch.Tensor pos: The position tensor. + """ + if pos is not None: + check_consistency(pos, (torch.Tensor, LabelTensor)) + if pos.ndim != 2: + raise ValueError("pos must be a 2D tensor.") @staticmethod - def _check_len_consistency(x, pos): - if isinstance(x, list) and isinstance(pos, list): - if len(x) != len(pos): - raise ValueError("x and pos must have the same length.") - return max(len(x), len(pos)) - elif isinstance(x, list) and not isinstance(pos, list): - return len(x) - elif not isinstance(x, list) and isinstance(pos, list): - return len(pos) - else: - return 1 + def _check_edge_index_consistency(edge_index): + """ + Check if the edge index is consistent. + :param torch.Tensor edge_index: The edge index tensor. + """ + check_consistency(edge_index, (torch.Tensor, LabelTensor)) + if edge_index.ndim != 2: + raise ValueError("edge_index must be a 2D tensor.") + if edge_index.size(0) != 2: + raise ValueError("edge_index must have shape [2, num_edges].") @staticmethod - def _check_input_consistency(x, pos, edge_index=None): - # If x is a 3D tensor, we split it into a list of 2D tensors - if isinstance(x, torch.Tensor) and x.ndim == 3: - x = [x[i] for i in range(x.shape[0])] - elif not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and not ( - isinstance(x, torch.Tensor) and x.ndim == 2 - ): - raise TypeError( - "x must be either a list of 2D tensors or a 2D " - "tensor or a 3D tensor" - ) + def _check_edge_attr_consistency(edge_attr, edge_index): + """ + Check if the edge attr is consistent. + :param torch.Tensor edge_attr: The edge attribute tensor. - # If pos is a 3D tensor, we split it into a list of 2D tensors - if isinstance(pos, torch.Tensor) and pos.ndim == 3: - pos = [pos[i] for i in range(pos.shape[0])] - elif not ( - isinstance(pos, list) and all(t.ndim == 2 for t in pos) - ) and not (isinstance(pos, torch.Tensor) and pos.ndim == 2): - raise TypeError( - "pos must be either a list of 2D tensors or a 2D " - "tensor or a 3D tensor" - ) - - # If edge_index is a 3D tensor, we split it into a list of 2D tensors - if edge_index is not None: - if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3: - edge_index = [edge_index[i] for i in range(edge_index.shape[0])] - elif not ( - isinstance(edge_index, list) - and all(t.ndim == 2 for t in edge_index) - ) and not ( - isinstance(edge_index, torch.Tensor) and edge_index.ndim == 2 - ): - raise TypeError( - "edge_index must be either a list of 2D tensors or a 2D " - "tensor or a 3D tensor" - ) - - return x, pos, edge_index - - @staticmethod - def _check_additional_params(additional_params, data_len): - if additional_params is not None: - if not isinstance(additional_params, dict): - raise TypeError("additional_params must be a dictionary.") - for param, val in additional_params.items(): - # Check if the values are tensors or lists of tensors - if isinstance(val, torch.Tensor): - # If the tensor is 3D, we split it into a list of 2D tensors - # In this case there must be a additional parameter for each - # node - if val.ndim == 3: - additional_params[param] = [ - val[i] for i in range(val.shape[0]) - ] - # If the tensor is 2D, we replicate it for each node - elif val.ndim == 2: - additional_params[param] = [val] * data_len - # If the tensor is 1D, each graph has a scalar values as - # additional parameter - if val.ndim == 1: - if len(val) == data_len: - additional_params[param] = [ - val[i] for i in range(len(val)) - ] - else: - additional_params[param] = [ - val for _ in range(data_len) - ] - elif not isinstance(val, list): - raise TypeError( - "additional_params values must be tensors " - "or lists of tensors." - ) - else: - additional_params = {} - return additional_params - - def _check_and_build_edge_attr( - self, edge_attr, build_edge_attr, data_len, edge_index, pos, x - ): - # Check if edge_attr is consistent with x and pos + :param torch.Tensor edge_index: The edge index tensor. + """ if edge_attr is not None: - if build_edge_attr is True: - warning( - "edge_attr is not None. build_edge_attr will not be " - "considered." + check_consistency(edge_attr, (torch.Tensor, LabelTensor)) + if edge_attr.ndim != 2: + raise ValueError("edge_attr must be a 2D tensor.") + if edge_attr.size(0) != edge_index.size(1): + raise ValueError( + "edge_attr must have shape " + "[num_edges, num_edge_features], expected " + f"num_edges {edge_index.size(1)} " + f"got {edge_attr.size(0)}." ) - if isinstance(edge_attr, list): - if len(edge_attr) != data_len: - raise TypeError( - "edge_attr must have the same length as x " "and pos." - ) - return [edge_attr] * data_len - - if build_edge_attr: - return [ - self._build_edge_attr(x_, pos_, edge_index_) - for x_, pos_, edge_index_ in zip(x, pos, edge_index) - ] - - -class RadiusGraph(Graph): - def __init__(self, x, pos, r, **kwargs): - x, pos, edge_index = Graph._check_input_consistency(x, pos) - - if isinstance(pos, (torch.Tensor, LabelTensor)): - edge_index = RadiusGraph._radius_graph(pos, r) - else: - edge_index = [RadiusGraph._radius_graph(p, r) for p in pos] - - super().__init__(x=x, pos=pos, edge_index=edge_index, **kwargs) @staticmethod - def _radius_graph(points, r): + def _check_x_consistency(x, pos=None): """ - Implementation of the radius graph construction. - :param points: The input points. - :type points: torch.Tensor - :param r: The radius. - :type r: float - :return: The edge index. + Check if the input tensor x is consistent with the position tensor pos. + :param torch.Tensor x: The input tensor. + :param torch.Tensor pos: The position tensor. + """ + if x is not None: + check_consistency(x, (torch.Tensor, LabelTensor)) + if x.ndim != 2: + raise ValueError("x must be a 2D tensor.") + if pos is not None: + if x.size(0) != pos.size(0): + raise ValueError("Inconsistent number of nodes.") + if pos is not None: + if x.size(0) != pos.size(0): + raise ValueError("Inconsistent number of nodes.") + + @staticmethod + def _preprocess_edge_index(edge_index, undirected): + """ + Preprocess the edge index. + :param torch.Tensor edge_index: The edge index. + :param bool undirected: Whether the graph is undirected. + :return: The preprocessed edge index. :rtype: torch.Tensor """ - dist = torch.cdist(points, points, p=2) - edge_index = torch.nonzero(dist <= r, as_tuple=False).t() - if isinstance(edge_index, LabelTensor): - edge_index = edge_index.tensor + if undirected: + edge_index = to_undirected(edge_index) return edge_index -class KNNGraph(Graph): - def __init__(self, x, pos, k, **kwargs): - x, pos, edge_index = Graph._check_input_consistency(x, pos) - if isinstance(pos, (torch.Tensor, LabelTensor)): - edge_index = KNNGraph._knn_graph(pos, k) - else: - edge_index = [KNNGraph._knn_graph(p, k) for p in pos] - super().__init__(x=x, pos=pos, edge_index=edge_index, **kwargs) +class GraphBuilder: + """ + A class that allows the simple definition of Graph instances. + """ + + def __new__( + cls, + pos, + edge_index, + x=None, + edge_attr=False, + custom_edge_func=None, + **kwargs, + ): + """ + Creates a new instance of the Graph class. + + :param pos: A tensor of shape (N, D) representing the positions of N + points in D-dimensional space. + :type pos: torch.Tensor | LabelTensor + :param edge_index: A tensor of shape (2, E) representing the indices of + the graph's edges. + :type edge_index: torch.Tensor + :param x: Optional tensor of node features (N, F) where F is the number + of features per node. + :type x: torch.Tensor, LabelTensor + :param bool edge_attr: Optional edge attributes (E, F) where F is the + number of features per edge. + :param callable custom_edge_func: A custom function to compute edge + attributes. + :param kwargs: Additional keyword arguments passed to the Graph class + constructor. + :return: A Graph instance constructed using the provided information. + :rtype: Graph + """ + edge_attr = cls._create_edge_attr( + pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr + ) + return Graph( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + pos=pos, + **kwargs, + ) @staticmethod - def _knn_graph(points, k): + def _create_edge_attr(pos, edge_index, edge_attr, func): + check_consistency(edge_attr, bool) + if edge_attr: + if is_function(func): + return func(pos, edge_index) + raise ValueError("custom_edge_func must be a function.") + return None + + @staticmethod + def _build_edge_attr(pos, edge_index): + return ( + (pos[edge_index[0]] - pos[edge_index[1]]) + .abs() + .as_subclass(torch.Tensor) + ) + + +class RadiusGraph(GraphBuilder): + """ + A class to build a radius graph. + """ + + def __new__(cls, pos, radius, **kwargs): """ - Implementation of the k-nearest neighbors graph construction. - :param points: The input points. - :type points: torch.Tensor - :param k: The number of nearest neighbors. - :type k: int - :return: The edge index. - :rtype: torch.Tensor + Creates a new instance of the Graph class using a radius-based graph + construction. + + :param pos: A tensor of shape (N, D) representing the positions of N + points in D-dimensional space. + :type pos: torch.Tensor | LabelTensor + :param float radius: The radius within which points are connected. + :Keyword Arguments: + The additional keyword arguments to be passed to GraphBuilder + and Graph classes + :return: Graph instance containg the information passed in input and + the computed edge_index + :rtype: Graph """ + edge_index = cls.compute_radius_graph(pos, radius) + return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) + + @staticmethod + def compute_radius_graph(points, radius): + """ + Computes a radius-based graph for a given set of points. + + :param points: A tensor of shape (N, D) representing the positions of + N points in D-dimensional space. + :type points: torch.Tensor | LabelTensor + :param float radius: The number of nearest neighbors to find for each + point. + :rtype torch.Tensor: A tensor of shape (2, E), where E is the number of + edges, representing the edge indices of the KNN graph. + """ + dist = torch.cdist(points, points, p=2) + return ( + torch.nonzero(dist <= radius, as_tuple=False) + .t() + .as_subclass(torch.Tensor) + ) + + +class KNNGraph(GraphBuilder): + """ + A class to build a KNN graph. + """ + + def __new__(cls, pos, neighbours, **kwargs): + """ + Creates a new instance of the Graph class using k-nearest neighbors + to compute edge_index. + + :param pos: A tensor of shape (N, D) representing the positions of N + points in D-dimensional space. + :type pos: torch.Tensor | LabelTensor + :param int neighbours: The number of nearest neighbors to consider when + building the graph. + :Keyword Arguments: + The additional keyword arguments to be passed to GraphBuilder + and Graph classes + + :return: Graph instance containg the information passed in input and + the computed edge_index + :rtype: Graph + """ + + edge_index = cls.compute_knn_graph(pos, neighbours) + return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) + + @staticmethod + def compute_knn_graph(points, k): + """ + Computes the edge_index based k-nearest neighbors graph algorithm + + :param points: A tensor of shape (N, D) representing the positions of + N points in D-dimensional space. + :type points: torch.Tensor | LabelTensor + :param int k: The number of nearest neighbors to find for each point. + :rtype torch.Tensor: A tensor of shape (2, E), where E is the number of + edges, representing the edge indices of the KNN graph. + """ + dist = torch.cdist(points, points, p=2) knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:] row = torch.arange(points.size(0)).repeat_interleave(k) col = knn_indices.flatten() - edge_index = torch.stack([row, col], dim=0) - if isinstance(edge_index, LabelTensor): - edge_index = edge_index.tensor - return edge_index + return torch.stack([row, col], dim=0).as_subclass(torch.Tensor) diff --git a/tests/test_collector.py b/tests/test_collector.py index 284d643..565fed3 100644 --- a/tests/test_collector.py +++ b/tests/test_collector.py @@ -15,12 +15,18 @@ def test_supervised_tensor_collector(): class SupervisedProblem(AbstractProblem): output_variables = None conditions = { - 'data1': Condition(input_points=torch.rand((10, 2)), - output_points=torch.rand((10, 2))), - 'data2': Condition(input_points=torch.rand((20, 2)), - output_points=torch.rand((20, 2))), - 'data3': Condition(input_points=torch.rand((30, 2)), - output_points=torch.rand((30, 2))), + "data1": Condition( + input_points=torch.rand((10, 2)), + output_points=torch.rand((10, 2)), + ), + "data2": Condition( + input_points=torch.rand((20, 2)), + output_points=torch.rand((20, 2)), + ), + "data3": Condition( + input_points=torch.rand((30, 2)), + output_points=torch.rand((30, 2)), + ), } problem = SupervisedProblem() @@ -31,65 +37,58 @@ def test_supervised_tensor_collector(): def test_pinn_collector(): def laplace_equation(input_, output_): - force_term = (torch.sin(input_.extract(['x']) * torch.pi) * - torch.sin(input_.extract(['y']) * torch.pi)) - delta_u = laplacian(output_.extract(['u']), input_) + force_term = torch.sin(input_.extract(["x"]) * torch.pi) * torch.sin( + input_.extract(["y"]) * torch.pi + ) + delta_u = laplacian(output_.extract(["u"]), input_) return delta_u - force_term my_laplace = Equation(laplace_equation) - in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y']) - out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u']) + in_ = LabelTensor( + torch.tensor([[0.0, 1.0]], requires_grad=True), ["x", "y"] + ) + out_ = LabelTensor(torch.tensor([[0.0]], requires_grad=True), ["u"]) class Poisson(SpatialProblem): - output_variables = ['u'] - spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) + output_variables = ["u"] + spatial_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) conditions = { - 'gamma1': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': 1 - }), - equation=FixedValue(0.0)), - 'gamma2': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': 0 - }), - equation=FixedValue(0.0)), - 'gamma3': - Condition(domain=CartesianDomain({ - 'x': 1, - 'y': [0, 1] - }), - equation=FixedValue(0.0)), - 'gamma4': - Condition(domain=CartesianDomain({ - 'x': 0, - 'y': [0, 1] - }), - equation=FixedValue(0.0)), - 'D': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': [0, 1] - }), - equation=my_laplace), - 'data': - Condition(input_points=in_, output_points=out_) + "gamma1": Condition( + domain=CartesianDomain({"x": [0, 1], "y": 1}), + equation=FixedValue(0.0), + ), + "gamma2": Condition( + domain=CartesianDomain({"x": [0, 1], "y": 0}), + equation=FixedValue(0.0), + ), + "gamma3": Condition( + domain=CartesianDomain({"x": 1, "y": [0, 1]}), + equation=FixedValue(0.0), + ), + "gamma4": Condition( + domain=CartesianDomain({"x": 0, "y": [0, 1]}), + equation=FixedValue(0.0), + ), + "D": Condition( + domain=CartesianDomain({"x": [0, 1], "y": [0, 1]}), + equation=my_laplace, + ), + "data": Condition(input_points=in_, output_points=out_), } def poisson_sol(self, pts): - return -(torch.sin(pts.extract(['x']) * torch.pi) * - torch.sin(pts.extract(['y']) * torch.pi)) / ( - 2 * torch.pi ** 2) + return -( + torch.sin(pts.extract(["x"]) * torch.pi) + * torch.sin(pts.extract(["y"]) * torch.pi) + ) / (2 * torch.pi**2) truth_solution = poisson_sol problem = Poisson() - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - problem.discretise_domain(10, 'grid', domains=boundaries) - problem.discretise_domain(10, 'grid', domains='D') + boundaries = ["gamma1", "gamma2", "gamma3", "gamma4"] + problem.discretise_domain(10, "grid", domains=boundaries) + problem.discretise_domain(10, "grid", domains="D") collector = Collector(problem) collector.store_fixed_data() @@ -98,31 +97,34 @@ def test_pinn_collector(): for k, v in problem.conditions.items(): if isinstance(v, InputOutputPointsCondition): assert list(collector.data_collections[k].keys()) == [ - 'input_points', 'output_points'] + "input_points", + "output_points", + ] for k, v in problem.conditions.items(): if isinstance(v, DomainEquationCondition): assert list(collector.data_collections[k].keys()) == [ - 'input_points', 'equation'] + "input_points", + "equation", + ] def test_supervised_graph_collector(): pos = torch.rand((100, 3)) x = [torch.rand((100, 3)) for _ in range(10)] - graph_list_1 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4) + graph_list_1 = [RadiusGraph(pos=pos, radius=0.4, x=x_) for x_ in x] out_1 = torch.rand((10, 100, 3)) + pos = torch.rand((50, 3)) x = [torch.rand((50, 3)) for _ in range(10)] - graph_list_2 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4) + graph_list_2 = [RadiusGraph(pos=pos, radius=0.4, x=x_) for x_ in x] out_2 = torch.rand((10, 50, 3)) class SupervisedProblem(AbstractProblem): output_variables = None conditions = { - 'data1': Condition(input_points=graph_list_1, - output_points=out_1), - 'data2': Condition(input_points=graph_list_2, - output_points=out_2), + "data1": Condition(input_points=graph_list_1, output_points=out_1), + "data2": Condition(input_points=graph_list_2, output_points=out_2), } problem = SupervisedProblem() diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 19b0ec3..2d7de9d 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -15,16 +15,15 @@ output_tensor = torch.rand((100, 2)) x = torch.rand((100, 50, 10)) pos = torch.rand((100, 50, 2)) -input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True) +input_graph = [ + RadiusGraph(x=x_, pos=pos_, radius=0.2) for x_, pos_, in zip(x, pos) +] output_graph = torch.rand((100, 50, 10)) @pytest.mark.parametrize( "input_, output_", - [ - (input_tensor, output_tensor), - (input_graph, output_graph) - ] + [(input_tensor, output_tensor), (input_graph, output_graph)], ) def test_constructor(input_, output_): problem = SupervisedProblem(input_=input_, output_=output_) @@ -33,22 +32,16 @@ def test_constructor(input_, output_): @pytest.mark.parametrize( "input_, output_", - [ - (input_tensor, output_tensor), - (input_graph, output_graph) - ] + [(input_tensor, output_tensor), (input_graph, output_graph)], ) @pytest.mark.parametrize( - "train_size, val_size, test_size", - [ - (.7, .2, .1), - (.7, .3, 0) - ] + "train_size, val_size, test_size", [(0.7, 0.2, 0.1), (0.7, 0.3, 0)] ) def test_setup_train(input_, output_, train_size, val_size, test_size): problem = SupervisedProblem(input_=input_, output_=output_) - dm = PinaDataModule(problem, train_size=train_size, - val_size=val_size, test_size=test_size) + dm = PinaDataModule( + problem, train_size=train_size, val_size=val_size, test_size=test_size + ) dm.setup() assert hasattr(dm, "train_dataset") if isinstance(input_, torch.Tensor): @@ -71,23 +64,17 @@ def test_setup_train(input_, output_, train_size, val_size, test_size): @pytest.mark.parametrize( "input_, output_", - [ - (input_tensor, output_tensor), - (input_graph, output_graph) - ] + [(input_tensor, output_tensor), (input_graph, output_graph)], ) @pytest.mark.parametrize( - "train_size, val_size, test_size", - [ - (.7, .2, .1), - (0., 0., 1.) - ] + "train_size, val_size, test_size", [(0.7, 0.2, 0.1), (0.0, 0.0, 1.0)] ) def test_setup_test(input_, output_, train_size, val_size, test_size): problem = SupervisedProblem(input_=input_, output_=output_) - dm = PinaDataModule(problem, train_size=train_size, - val_size=val_size, test_size=test_size) - dm.setup(stage='test') + dm = PinaDataModule( + problem, train_size=train_size, val_size=val_size, test_size=test_size + ) + dm.setup(stage="test") if train_size > 0: assert hasattr(dm, "train_dataset") assert dm.train_dataset is None @@ -109,16 +96,14 @@ def test_setup_test(input_, output_, train_size, val_size, test_size): @pytest.mark.parametrize( "input_, output_", - [ - (input_tensor, output_tensor), - (input_graph, output_graph) - ] + [(input_tensor, output_tensor), (input_graph, output_graph)], ) def test_dummy_dataloader(input_, output_): problem = SupervisedProblem(input_=input_, output_=output_) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer(solver, batch_size=None, train_size=.7, - val_size=.3, test_size=0.) + trainer = Trainer( + solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0 + ) dm = trainer.data_module dm.setup() dm.trainer = trainer @@ -128,11 +113,11 @@ def test_dummy_dataloader(input_, output_): data = next(dataloader) assert isinstance(data, list) assert isinstance(data[0], tuple) - if isinstance(input_, RadiusGraph): - assert isinstance(data[0][1]['input_points'], Batch) + if isinstance(input_, list): + assert isinstance(data[0][1]["input_points"], Batch) else: - assert isinstance(data[0][1]['input_points'], torch.Tensor) - assert isinstance(data[0][1]['output_points'], torch.Tensor) + assert isinstance(data[0][1]["input_points"], torch.Tensor) + assert isinstance(data[0][1]["output_points"], torch.Tensor) dataloader = dm.val_dataloader() assert isinstance(dataloader, DummyDataloader) @@ -140,31 +125,29 @@ def test_dummy_dataloader(input_, output_): data = next(dataloader) assert isinstance(data, list) assert isinstance(data[0], tuple) - if isinstance(input_, RadiusGraph): - assert isinstance(data[0][1]['input_points'], Batch) + if isinstance(input_, list): + assert isinstance(data[0][1]["input_points"], Batch) else: - assert isinstance(data[0][1]['input_points'], torch.Tensor) - assert isinstance(data[0][1]['output_points'], torch.Tensor) + assert isinstance(data[0][1]["input_points"], torch.Tensor) + assert isinstance(data[0][1]["output_points"], torch.Tensor) @pytest.mark.parametrize( "input_, output_", - [ - (input_tensor, output_tensor), - (input_graph, output_graph) - ] -) -@pytest.mark.parametrize( - "automatic_batching", - [ - True, False - ] + [(input_tensor, output_tensor), (input_graph, output_graph)], ) +@pytest.mark.parametrize("automatic_batching", [True, False]) def test_dataloader(input_, output_, automatic_batching): problem = SupervisedProblem(input_=input_, output_=output_) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3, - test_size=0., automatic_batching=automatic_batching) + trainer = Trainer( + solver, + batch_size=10, + train_size=0.7, + val_size=0.3, + test_size=0.0, + automatic_batching=automatic_batching, + ) dm = trainer.data_module dm.setup() dm.trainer = trainer @@ -173,51 +156,53 @@ def test_dataloader(input_, output_, automatic_batching): assert len(dataloader) == 7 data = next(iter(dataloader)) assert isinstance(data, dict) - if isinstance(input_, RadiusGraph): - assert isinstance(data['data']['input_points'], Batch) + if isinstance(input_, list): + assert isinstance(data["data"]["input_points"], Batch) else: - assert isinstance(data['data']['input_points'], torch.Tensor) - assert isinstance(data['data']['output_points'], torch.Tensor) + assert isinstance(data["data"]["input_points"], torch.Tensor) + assert isinstance(data["data"]["output_points"], torch.Tensor) dataloader = dm.val_dataloader() assert isinstance(dataloader, DataLoader) assert len(dataloader) == 3 data = next(iter(dataloader)) assert isinstance(data, dict) - if isinstance(input_, RadiusGraph): - assert isinstance(data['data']['input_points'], Batch) + if isinstance(input_, list): + assert isinstance(data["data"]["input_points"], Batch) else: - assert isinstance(data['data']['input_points'], torch.Tensor) - assert isinstance(data['data']['output_points'], torch.Tensor) + assert isinstance(data["data"]["input_points"], torch.Tensor) + assert isinstance(data["data"]["output_points"], torch.Tensor) + from pina import LabelTensor -input_tensor = LabelTensor(torch.rand((100, 3)), ['u', 'v', 'w']) -output_tensor = LabelTensor(torch.rand((100, 3)), ['u', 'v', 'w']) +input_tensor = LabelTensor(torch.rand((100, 3)), ["u", "v", "w"]) +output_tensor = LabelTensor(torch.rand((100, 3)), ["u", "v", "w"]) + +x = LabelTensor(torch.rand((100, 50, 3)), ["u", "v", "w"]) +pos = LabelTensor(torch.rand((100, 50, 2)), ["x", "y"]) +input_graph = [ + RadiusGraph(x=x[i], pos=pos[i], radius=0.1) for i in range(len(x)) +] +output_graph = LabelTensor(torch.rand((100, 50, 3)), ["u", "v", "w"]) -x = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w']) -pos = LabelTensor(torch.rand((100, 50, 2)), ['x', 'y']) -input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True) -output_graph = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w']) @pytest.mark.parametrize( "input_, output_", - [ - (input_tensor, output_tensor), - (input_graph, output_graph) - ] -) -@pytest.mark.parametrize( - "automatic_batching", - [ - True, False - ] + [(input_tensor, output_tensor), (input_graph, output_graph)], ) +@pytest.mark.parametrize("automatic_batching", [True, False]) def test_dataloader_labels(input_, output_, automatic_batching): problem = SupervisedProblem(input_=input_, output_=output_) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3, - test_size=0., automatic_batching=automatic_batching) + trainer = Trainer( + solver, + batch_size=10, + train_size=0.7, + val_size=0.3, + test_size=0.0, + automatic_batching=automatic_batching, + ) dm = trainer.data_module dm.setup() dm.trainer = trainer @@ -226,31 +211,30 @@ def test_dataloader_labels(input_, output_, automatic_batching): assert len(dataloader) == 7 data = next(iter(dataloader)) assert isinstance(data, dict) - if isinstance(input_, RadiusGraph): - assert isinstance(data['data']['input_points'], Batch) - assert isinstance(data['data']['input_points'].x, LabelTensor) - assert data['data']['input_points'].x.labels == ['u', 'v', 'w'] - assert data['data']['input_points'].pos.labels == ['x', 'y'] - else: - assert isinstance(data['data']['input_points'], LabelTensor) - assert data['data']['input_points'].labels == ['u', 'v', 'w'] - assert isinstance(data['data']['output_points'], LabelTensor) - assert data['data']['output_points'].labels == ['u', 'v', 'w'] + if isinstance(input_, list): + assert isinstance(data["data"]["input_points"], Batch) + assert isinstance(data["data"]["input_points"].x, LabelTensor) + assert data["data"]["input_points"].x.labels == ["u", "v", "w"] + assert data["data"]["input_points"].pos.labels == ["x", "y"] + else: + assert isinstance(data["data"]["input_points"], LabelTensor) + assert data["data"]["input_points"].labels == ["u", "v", "w"] + assert isinstance(data["data"]["output_points"], LabelTensor) + assert data["data"]["output_points"].labels == ["u", "v", "w"] dataloader = dm.val_dataloader() assert isinstance(dataloader, DataLoader) assert len(dataloader) == 3 data = next(iter(dataloader)) assert isinstance(data, dict) - if isinstance(input_, RadiusGraph): - assert isinstance(data['data']['input_points'], Batch) - assert isinstance(data['data']['input_points'].x, LabelTensor) - assert data['data']['input_points'].x.labels == ['u', 'v', 'w'] - assert data['data']['input_points'].pos.labels == ['x', 'y'] + if isinstance(input_, list): + assert isinstance(data["data"]["input_points"], Batch) + assert isinstance(data["data"]["input_points"].x, LabelTensor) + assert data["data"]["input_points"].x.labels == ["u", "v", "w"] + assert data["data"]["input_points"].pos.labels == ["x", "y"] else: - assert isinstance(data['data']['input_points'], torch.Tensor) - assert isinstance(data['data']['input_points'], LabelTensor) - assert data['data']['input_points'].labels == ['u', 'v', 'w'] - assert isinstance(data['data']['output_points'], torch.Tensor) - assert data['data']['output_points'].labels == ['u', 'v', 'w'] -test_dataloader_labels(input_graph, output_graph, True) \ No newline at end of file + assert isinstance(data["data"]["input_points"], torch.Tensor) + assert isinstance(data["data"]["input_points"], LabelTensor) + assert data["data"]["input_points"].labels == ["u", "v", "w"] + assert isinstance(data["data"]["output_points"], torch.Tensor) + assert data["data"]["output_points"].labels == ["u", "v", "w"] diff --git a/tests/test_data/test_graph_dataset.py b/tests/test_data/test_graph_dataset.py index 15da4bf..4acb19e 100644 --- a/tests/test_data/test_graph_dataset.py +++ b/tests/test_data/test_graph_dataset.py @@ -6,55 +6,58 @@ from torch_geometric.data import Data x = torch.rand((100, 20, 10)) pos = torch.rand((100, 20, 2)) -input_ = KNNGraph(x=x, pos=pos, k=3, build_edge_attr=True) +input_ = [ + KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) + for x_, pos_ in zip(x, pos) +] output_ = torch.rand((100, 20, 10)) x_2 = torch.rand((50, 20, 10)) pos_2 = torch.rand((50, 20, 2)) -input_2_ = KNNGraph(x=x_2, pos=pos_2, k=3, build_edge_attr=True) +input_2_ = [ + KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) + for x_, pos_ in zip(x_2, pos_2) +] output_2_ = torch.rand((50, 20, 10)) # Problem with a single condition conditions_dict_single = { - 'data': { - 'input_points': input_.data, - 'output_points': output_, + "data": { + "input_points": input_, + "output_points": output_, } } -max_conditions_lengths_single = { - 'data': 100 -} +max_conditions_lengths_single = {"data": 100} # Problem with multiple conditions conditions_dict_single_multi = { - 'data_1': { - 'input_points': input_.data, - 'output_points': output_, + "data_1": { + "input_points": input_, + "output_points": output_, + }, + "data_2": { + "input_points": input_2_, + "output_points": output_2_, }, - 'data_2': { - 'input_points': input_2_.data, - 'output_points': output_2_, - } } -max_conditions_lengths_multi = { - 'data_1': 100, - 'data_2': 50 -} +max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} @pytest.mark.parametrize( "conditions_dict, max_conditions_lengths", [ (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_single_multi, max_conditions_lengths_multi) - ] + (conditions_dict_single_multi, max_conditions_lengths_multi), + ], ) def test_constructor(conditions_dict, max_conditions_lengths): - dataset = PinaDatasetFactory(conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True) + dataset = PinaDatasetFactory( + conditions_dict, + max_conditions_lengths=max_conditions_lengths, + automatic_batching=True, + ) assert isinstance(dataset, PinaGraphDataset) assert len(dataset) == 100 @@ -63,39 +66,67 @@ def test_constructor(conditions_dict, max_conditions_lengths): "conditions_dict, max_conditions_lengths", [ (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_single_multi, max_conditions_lengths_multi) - ] + (conditions_dict_single_multi, max_conditions_lengths_multi), + ], ) def test_getitem(conditions_dict, max_conditions_lengths): - dataset = PinaDatasetFactory(conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True) + dataset = PinaDatasetFactory( + conditions_dict, + max_conditions_lengths=max_conditions_lengths, + automatic_batching=True, + ) data = dataset[50] assert isinstance(data, dict) - assert all([isinstance(d['input_points'], Data) - for d in data.values()]) - assert all([isinstance(d['output_points'], torch.Tensor) - for d in data.values()]) - assert all([d['input_points'].x.shape == torch.Size((20, 10)) - for d in data.values()]) - assert all([d['output_points'].shape == torch.Size((20, 10)) - for d in data.values()]) - assert all([d['input_points'].edge_index.shape == - torch.Size((2, 60)) for d in data.values()]) - assert all([d['input_points'].edge_attr.shape[0] - == 60 for d in data.values()]) + assert all([isinstance(d["input_points"], Data) for d in data.values()]) + assert all( + [isinstance(d["output_points"], torch.Tensor) for d in data.values()] + ) + assert all( + [ + d["input_points"].x.shape == torch.Size((20, 10)) + for d in data.values() + ] + ) + assert all( + [ + d["output_points"].shape == torch.Size((20, 10)) + for d in data.values() + ] + ) + assert all( + [ + d["input_points"].edge_index.shape == torch.Size((2, 60)) + for d in data.values() + ] + ) + assert all( + [d["input_points"].edge_attr.shape[0] == 60 for d in data.values()] + ) data = dataset.fetch_from_idx_list([i for i in range(20)]) assert isinstance(data, dict) - assert all([isinstance(d['input_points'], Data) - for d in data.values()]) - assert all([isinstance(d['output_points'], torch.Tensor) - for d in data.values()]) - assert all([d['input_points'].x.shape == torch.Size((400, 10)) - for d in data.values()]) - assert all([d['output_points'].shape == torch.Size((400, 10)) - for d in data.values()]) - assert all([d['input_points'].edge_index.shape == - torch.Size((2, 1200)) for d in data.values()]) - assert all([d['input_points'].edge_attr.shape[0] - == 1200 for d in data.values()]) + assert all([isinstance(d["input_points"], Data) for d in data.values()]) + assert all( + [isinstance(d["output_points"], torch.Tensor) for d in data.values()] + ) + assert all( + [ + d["input_points"].x.shape == torch.Size((400, 10)) + for d in data.values() + ] + ) + assert all( + [ + d["output_points"].shape == torch.Size((400, 10)) + for d in data.values() + ] + ) + assert all( + [ + d["input_points"].edge_index.shape == torch.Size((2, 1200)) + for d in data.values() + ] + ) + assert all( + [d["input_points"].edge_attr.shape[0] == 1200 for d in data.values()] + ) diff --git a/tests/test_graph.py b/tests/test_graph.py index 660ec34..bf053a8 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,163 +1,346 @@ import pytest import torch -from pina.graph import RadiusGraph, KNNGraph +from pina import LabelTensor +from pina.graph import RadiusGraph, KNNGraph, Graph +from torch_geometric.data import Data + + +def build_edge_attr(pos, edge_index): + return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1) @pytest.mark.parametrize( "x, pos", [ - ([torch.rand(10, 2) for _ in range(3)], - [torch.rand(10, 3) for _ in range(3)]), - ([torch.rand(10, 2) for _ in range(3)], - [torch.rand(10, 3) for _ in range(3)]), - (torch.rand(3, 10, 2), torch.rand(3, 10, 3)), - (torch.rand(3, 10, 2), torch.rand(3, 10, 3)), - ] + (torch.rand(10, 2), torch.rand(10, 3)), + ( + LabelTensor(torch.rand(10, 2), ["u", "v"]), + LabelTensor(torch.rand(10, 3), ["x", "y", "z"]), + ), + ], ) -def test_build_multiple_graph_multiple_val(x, pos): - graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3) - assert len(graph.data) == 3 - data = graph.data - assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) - assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) - assert all(len(d.edge_index) == 2 for d in data) - graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3) - data = graph.data - assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) - assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) - assert all(len(d.edge_index) == 2 for d in data) - assert all(d.edge_attr is not None for d in data) - assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) +def test_build_graph(x, pos): + edge_index = torch.tensor( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]], + dtype=torch.int64, + ) + graph = Graph(x=x, pos=pos, edge_index=edge_index) + assert hasattr(graph, "x") + assert hasattr(graph, "pos") + assert hasattr(graph, "edge_index") + assert torch.isclose(graph.x, x).all() + if isinstance(x, LabelTensor): + assert isinstance(graph.x, LabelTensor) + assert graph.x.labels == x.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert torch.isclose(graph.pos, pos).all() + if isinstance(pos, LabelTensor): + assert isinstance(graph.pos, LabelTensor) + assert graph.pos.labels == pos.labels + else: + assert isinstance(graph.pos, torch.Tensor) - graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3) - data = graph.data - assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) - assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) - assert all(len(d.edge_index) == 2 for d in data) - assert all(d.edge_attr is not None for d in data) - assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) - - -def test_build_single_graph_multiple_val(): - x = torch.rand(10, 2) - pos = torch.rand(10, 3) - graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3) - assert len(graph.data) == 1 - data = graph.data - assert all(torch.isclose(d.x, x).all() for d in data) - assert all(torch.isclose(d_.pos, pos).all() for d_ in data) - assert all(len(d.edge_index) == 2 for d in data) - graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3) - data = graph.data - assert len(graph.data) == 1 - assert all(torch.isclose(d.x, x).all() for d in data) - assert all(torch.isclose(d_.pos, pos).all() for d_ in data) - assert all(len(d.edge_index) == 2 for d in data) - assert all(d.edge_attr is not None for d in data) - assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) - - x = torch.rand(10, 2) - pos = torch.rand(10, 3) - graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3) - assert len(graph.data) == 1 - data = graph.data - assert all(torch.isclose(d.x, x).all() for d in data) - assert all(torch.isclose(d_.pos, pos).all() for d_ in data) - assert all(len(d.edge_index) == 2 for d in data) - graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3) - data = graph.data - assert len(graph.data) == 1 - assert all(torch.isclose(d.x, x).all() for d in data) - assert all(torch.isclose(d_.pos, pos).all() for d_ in data) - assert all(len(d.edge_index) == 2 for d in data) - assert all(d.edge_attr is not None for d in data) - assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) + edge_index = torch.tensor( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]], + dtype=torch.int64, + ) + graph = Graph(x=x, edge_index=edge_index) + assert hasattr(graph, "x") + assert hasattr(graph, "pos") + assert hasattr(graph, "edge_index") + assert torch.isclose(graph.x, x).all() + if isinstance(x, LabelTensor): + assert isinstance(graph.x, LabelTensor) + assert graph.x.labels == x.labels + else: + assert isinstance(graph.x, torch.Tensor) @pytest.mark.parametrize( - "pos", + "x, pos", [ - ([torch.rand(10, 3) for _ in range(3)]), - ([torch.rand(10, 3) for _ in range(3)]), - (torch.rand(3, 10, 3)), - (torch.rand(3, 10, 3)) - ] + (torch.rand(10, 2), torch.rand(10, 3)), + ( + LabelTensor(torch.rand(10, 2), ["u", "v"]), + LabelTensor(torch.rand(10, 3), ["x", "y", "z"]), + ), + ], ) -def test_build_single_graph_single_val(pos): - x = torch.rand(10, 2) - graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3) - assert len(graph.data) == 3 - data = graph.data - assert all(torch.isclose(d.x, x).all() for d in data) - assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) - assert all(len(d.edge_index) == 2 for d in data) - graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3) - data = graph.data - assert all(torch.isclose(d.x, x).all() for d in data) - assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) - assert all(len(d.edge_index) == 2 for d in data) - assert all(d.edge_attr is not None for d in data) - assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) - x = torch.rand(10, 2) - graph = KNNGraph(x=x, pos=pos, build_edge_attr=False, k=3) - assert len(graph.data) == 3 - data = graph.data - assert all(torch.isclose(d.x, x).all() for d in data) - assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) - assert all(len(d.edge_index) == 2 for d in data) - graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3) - data = graph.data - assert all(torch.isclose(d.x, x).all() for d in data) - assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos)) - assert all(len(d.edge_index) == 2 for d in data) - assert all(d.edge_attr is not None for d in data) - assert all([d.edge_index.shape[1] == d.edge_attr.shape[0]] for d in data) - - -def test_additional_parameters_1(): - x = torch.rand(3, 10, 2) - pos = torch.rand(3, 10, 2) - additional_parameters = {'y': torch.ones(3)} - graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3, - additional_params=additional_parameters) - assert len(graph.data) == 3 - data = graph.data - assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) - assert all(hasattr(d, 'y') for d in data) - assert all(d_.y == 1 for d_ in data) +def test_build_radius_graph(x, pos): + graph = RadiusGraph(x=x, pos=pos, radius=0.5) + assert hasattr(graph, "x") + assert hasattr(graph, "pos") + assert hasattr(graph, "edge_index") + assert torch.isclose(graph.x, x).all() + if isinstance(x, LabelTensor): + assert isinstance(graph.x, LabelTensor) + assert graph.x.labels == x.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert torch.isclose(graph.pos, pos).all() + if isinstance(pos, LabelTensor): + assert isinstance(graph.pos, LabelTensor) + assert graph.pos.labels == pos.labels + else: + assert isinstance(graph.pos, torch.Tensor) @pytest.mark.parametrize( - "additional_parameters", + "x, pos", [ - ({'y': torch.rand(3, 10, 1)}), - ({'y': [torch.rand(10, 1) for _ in range(3)]}), - ] + (torch.rand(10, 2), torch.rand(10, 3)), + ( + LabelTensor(torch.rand(10, 2), ["u", "v"]), + LabelTensor(torch.rand(10, 3), ["x", "y", "z"]), + ), + ], ) -def test_additional_parameters_2(additional_parameters): - x = torch.rand(3, 10, 2) - pos = torch.rand(3, 10, 2) - graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3, - additional_params=additional_parameters) - assert len(graph.data) == 3 - data = graph.data - assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) - assert all(hasattr(d, 'y') for d in data) - assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) +def test_build_radius_graph_edge_attr(x, pos): + graph = RadiusGraph(x=x, pos=pos, radius=0.5, edge_attr=True) + assert hasattr(graph, "x") + assert hasattr(graph, "pos") + assert hasattr(graph, "edge_index") + assert torch.isclose(graph.x, x).all() + if isinstance(x, LabelTensor): + assert isinstance(graph.x, LabelTensor) + assert graph.x.labels == x.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert torch.isclose(graph.pos, pos).all() + if isinstance(pos, LabelTensor): + assert isinstance(graph.pos, LabelTensor) + assert graph.pos.labels == pos.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert hasattr(graph, "edge_attr") + assert isinstance(graph.edge_attr, torch.Tensor) + assert graph.edge_attr.shape[-1] == 3 + assert graph.edge_attr.shape[0] == graph.edge_index.shape[1] -def test_custom_build_edge_attr_func(): - x = torch.rand(3, 10, 2) - pos = torch.rand(3, 10, 2) - def build_edge_attr(x, pos, edge_index): - return torch.cat([pos[edge_index[0]], pos[edge_index[1]]], dim=-1) +@pytest.mark.parametrize( + "x, pos", + [ + (torch.rand(10, 2), torch.rand(10, 3)), + ( + LabelTensor(torch.rand(10, 2), ["u", "v"]), + LabelTensor(torch.rand(10, 3), ["x", "y", "z"]), + ), + ], +) +def test_build_radius_graph_custom_edge_attr(x, pos): + graph = RadiusGraph( + x=x, + pos=pos, + radius=0.5, + edge_attr=True, + custom_edge_func=build_edge_attr, + ) + assert hasattr(graph, "x") + assert hasattr(graph, "pos") + assert hasattr(graph, "edge_index") + assert torch.isclose(graph.x, x).all() + if isinstance(x, LabelTensor): + assert isinstance(graph.x, LabelTensor) + assert graph.x.labels == x.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert torch.isclose(graph.pos, pos).all() + if isinstance(pos, LabelTensor): + assert isinstance(graph.pos, LabelTensor) + assert graph.pos.labels == pos.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert hasattr(graph, "edge_attr") + assert isinstance(graph.edge_attr, torch.Tensor) + assert graph.edge_attr.shape[-1] == 6 + assert graph.edge_attr.shape[0] == graph.edge_index.shape[1] - graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3, - custom_build_edge_attr=build_edge_attr) - assert len(graph.data) == 3 - data = graph.data - assert all(hasattr(d, 'edge_attr') for d in data) - assert all(d.edge_attr.shape[1] == 4 for d in data) - assert all(torch.isclose(d.edge_attr, - build_edge_attr(d.x, d.pos, d.edge_index)).all() - for d in data) + +@pytest.mark.parametrize( + "x, pos", + [ + (torch.rand(10, 2), torch.rand(10, 3)), + ( + LabelTensor(torch.rand(10, 2), ["u", "v"]), + LabelTensor(torch.rand(10, 3), ["x", "y", "z"]), + ), + ], +) +def test_build_knn_graph(x, pos): + graph = KNNGraph(x=x, pos=pos, neighbours=2) + assert hasattr(graph, "x") + assert hasattr(graph, "pos") + assert hasattr(graph, "edge_index") + assert torch.isclose(graph.x, x).all() + if isinstance(x, LabelTensor): + assert isinstance(graph.x, LabelTensor) + assert graph.x.labels == x.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert torch.isclose(graph.pos, pos).all() + if isinstance(pos, LabelTensor): + assert isinstance(graph.pos, LabelTensor) + assert graph.pos.labels == pos.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert graph.edge_attr is None + + +@pytest.mark.parametrize( + "x, pos", + [ + (torch.rand(10, 2), torch.rand(10, 3)), + ( + LabelTensor(torch.rand(10, 2), ["u", "v"]), + LabelTensor(torch.rand(10, 3), ["x", "y", "z"]), + ), + ], +) +def test_build_knn_graph_edge_attr(x, pos): + graph = KNNGraph(x=x, pos=pos, neighbours=2, edge_attr=True) + assert hasattr(graph, "x") + assert hasattr(graph, "pos") + assert hasattr(graph, "edge_index") + assert torch.isclose(graph.x, x).all() + if isinstance(x, LabelTensor): + assert isinstance(graph.x, LabelTensor) + assert graph.x.labels == x.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert torch.isclose(graph.pos, pos).all() + if isinstance(pos, LabelTensor): + assert isinstance(graph.pos, LabelTensor) + assert graph.pos.labels == pos.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert isinstance(graph.edge_attr, torch.Tensor) + assert graph.edge_attr.shape[-1] == 3 + assert graph.edge_attr.shape[0] == graph.edge_index.shape[1] + + +@pytest.mark.parametrize( + "x, pos", + [ + (torch.rand(10, 2), torch.rand(10, 3)), + ( + LabelTensor(torch.rand(10, 2), ["u", "v"]), + LabelTensor(torch.rand(10, 3), ["x", "y", "z"]), + ), + ], +) +def test_build_knn_graph_custom_edge_attr(x, pos): + graph = KNNGraph( + x=x, + pos=pos, + neighbours=2, + edge_attr=True, + custom_edge_func=build_edge_attr, + ) + assert hasattr(graph, "x") + assert hasattr(graph, "pos") + assert hasattr(graph, "edge_index") + assert torch.isclose(graph.x, x).all() + if isinstance(x, LabelTensor): + assert isinstance(graph.x, LabelTensor) + assert graph.x.labels == x.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert torch.isclose(graph.pos, pos).all() + if isinstance(pos, LabelTensor): + assert isinstance(graph.pos, LabelTensor) + assert graph.pos.labels == pos.labels + else: + assert isinstance(graph.pos, torch.Tensor) + assert isinstance(graph.edge_attr, torch.Tensor) + assert graph.edge_attr.shape[-1] == 6 + assert graph.edge_attr.shape[0] == graph.edge_index.shape[1] + + +@pytest.mark.parametrize( + "x, pos, y", + [ + (torch.rand(10, 2), torch.rand(10, 3), torch.rand(10, 4)), + ( + LabelTensor(torch.rand(10, 2), ["u", "v"]), + LabelTensor(torch.rand(10, 3), ["x", "y", "z"]), + LabelTensor(torch.rand(10, 4), ["a", "b", "c", "d"]), + ), + ], +) +def test_additional_params(x, pos, y): + edge_index = torch.tensor( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]], + dtype=torch.int64, + ) + graph = Graph(x=x, pos=pos, edge_index=edge_index, y=y) + assert hasattr(graph, "y") + assert torch.isclose(graph.y, y).all() + if isinstance(y, LabelTensor): + assert isinstance(graph.y, LabelTensor) + assert graph.y.labels == y.labels + else: + assert isinstance(graph.y, torch.Tensor) + assert torch.isclose(graph.y, y).all() + if isinstance(y, LabelTensor): + assert isinstance(graph.y, LabelTensor) + assert graph.y.labels == y.labels + else: + assert isinstance(graph.y, torch.Tensor) + + +@pytest.mark.parametrize( + "x, pos, y", + [ + (torch.rand(10, 2), torch.rand(10, 3), torch.rand(10, 4)), + ( + LabelTensor(torch.rand(10, 2), ["u", "v"]), + LabelTensor(torch.rand(10, 3), ["x", "y", "z"]), + LabelTensor(torch.rand(10, 4), ["a", "b", "c", "d"]), + ), + ], +) +def test_additional_params_radius_graph(x, pos, y): + graph = RadiusGraph(x=x, pos=pos, radius=0.5, y=y) + assert hasattr(graph, "y") + assert torch.isclose(graph.y, y).all() + if isinstance(y, LabelTensor): + assert isinstance(graph.y, LabelTensor) + assert graph.y.labels == y.labels + else: + assert isinstance(graph.y, torch.Tensor) + assert torch.isclose(graph.y, y).all() + if isinstance(y, LabelTensor): + assert isinstance(graph.y, LabelTensor) + assert graph.y.labels == y.labels + else: + assert isinstance(graph.y, torch.Tensor) + + +@pytest.mark.parametrize( + "x, pos, y", + [ + (torch.rand(10, 2), torch.rand(10, 3), torch.rand(10, 4)), + ( + LabelTensor(torch.rand(10, 2), ["u", "v"]), + LabelTensor(torch.rand(10, 3), ["x", "y", "z"]), + LabelTensor(torch.rand(10, 4), ["a", "b", "c", "d"]), + ), + ], +) +def test_additional_params_knn_graph(x, pos, y): + graph = KNNGraph(x=x, pos=pos, neighbours=3, y=y) + assert hasattr(graph, "y") + assert torch.isclose(graph.y, y).all() + if isinstance(y, LabelTensor): + assert isinstance(graph.y, LabelTensor) + assert graph.y.labels == y.labels + else: + assert isinstance(graph.y, torch.Tensor) + assert torch.isclose(graph.y, y).all() + if isinstance(y, LabelTensor): + assert isinstance(graph.y, LabelTensor) + assert graph.y.labels == y.labels + else: + assert isinstance(graph.y, torch.Tensor) diff --git a/tests/test_model/test_graph_neural_operator.py b/tests/test_model/test_graph_neural_operator.py index 8fb10d8..e2ea3ad 100644 --- a/tests/test_model/test_graph_neural_operator.py +++ b/tests/test_model/test_graph_neural_operator.py @@ -6,99 +6,90 @@ from torch_geometric.data import Batch x = [torch.rand(100, 6) for _ in range(10)] pos = [torch.rand(100, 3) for _ in range(10)] -graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=6) -input_ = Batch.from_data_list(graph.data) +graph = [ + KNNGraph(x=x_, pos=pos_, neighbours=6, edge_attr=True) + for x_, pos_ in zip(x, pos) +] +input_ = Batch.from_data_list(graph) -@pytest.mark.parametrize( - "shared_weights", - [ - True, - False - ] -) +@pytest.mark.parametrize("shared_weights", [True, False]) def test_constructor(shared_weights): lifting_operator = torch.nn.Linear(6, 16) projection_operator = torch.nn.Linear(16, 3) - GraphNeuralOperator(lifting_operator=lifting_operator, - projection_operator=projection_operator, - edge_features=3, - internal_layers=[16, 16], - shared_weights=shared_weights) + GraphNeuralOperator( + lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + internal_layers=[16, 16], + shared_weights=shared_weights, + ) - GraphNeuralOperator(lifting_operator=lifting_operator, - projection_operator=projection_operator, - edge_features=3, - inner_size=16, - internal_n_layers=10, - shared_weights=shared_weights) + GraphNeuralOperator( + lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + inner_size=16, + internal_n_layers=10, + shared_weights=shared_weights, + ) int_func = torch.nn.Softplus ext_func = torch.nn.ReLU - GraphNeuralOperator(lifting_operator=lifting_operator, - projection_operator=projection_operator, - edge_features=3, - internal_n_layers=10, - shared_weights=shared_weights, - internal_func=int_func, - external_func=ext_func) + GraphNeuralOperator( + lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + internal_n_layers=10, + shared_weights=shared_weights, + internal_func=int_func, + external_func=ext_func, + ) -@pytest.mark.parametrize( - "shared_weights", - [ - True, - False - ] -) +@pytest.mark.parametrize("shared_weights", [True, False]) def test_forward_1(shared_weights): lifting_operator = torch.nn.Linear(6, 16) projection_operator = torch.nn.Linear(16, 3) - model = GraphNeuralOperator(lifting_operator=lifting_operator, - projection_operator=projection_operator, - edge_features=3, - internal_layers=[16, 16], - shared_weights=shared_weights) + model = GraphNeuralOperator( + lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + internal_layers=[16, 16], + shared_weights=shared_weights, + ) output_ = model(input_) assert output_.shape == torch.Size([1000, 3]) -@pytest.mark.parametrize( - "shared_weights", - [ - True, - False - ] -) +@pytest.mark.parametrize("shared_weights", [True, False]) def test_forward_2(shared_weights): lifting_operator = torch.nn.Linear(6, 16) projection_operator = torch.nn.Linear(16, 3) - model = GraphNeuralOperator(lifting_operator=lifting_operator, - projection_operator=projection_operator, - edge_features=3, - inner_size=32, - internal_n_layers=2, - shared_weights=shared_weights) + model = GraphNeuralOperator( + lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + inner_size=32, + internal_n_layers=2, + shared_weights=shared_weights, + ) output_ = model(input_) assert output_.shape == torch.Size([1000, 3]) -@pytest.mark.parametrize( - "shared_weights", - [ - True, - False - ] -) +@pytest.mark.parametrize("shared_weights", [True, False]) def test_backward(shared_weights): lifting_operator = torch.nn.Linear(6, 16) projection_operator = torch.nn.Linear(16, 3) - model = GraphNeuralOperator(lifting_operator=lifting_operator, - projection_operator=projection_operator, - edge_features=3, - internal_layers=[16, 16], - shared_weights=shared_weights) + model = GraphNeuralOperator( + lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + internal_layers=[16, 16], + shared_weights=shared_weights, + ) input_.x.requires_grad = True output_ = model(input_) l = torch.mean(output_) @@ -106,22 +97,18 @@ def test_backward(shared_weights): assert input_.x.grad.shape == torch.Size([1000, 6]) -@pytest.mark.parametrize( - "shared_weights", - [ - True, - False - ] -) +@pytest.mark.parametrize("shared_weights", [True, False]) def test_backward_2(shared_weights): lifting_operator = torch.nn.Linear(6, 16) projection_operator = torch.nn.Linear(16, 3) - model = GraphNeuralOperator(lifting_operator=lifting_operator, - projection_operator=projection_operator, - edge_features=3, - inner_size=32, - internal_n_layers=2, - shared_weights=shared_weights) + model = GraphNeuralOperator( + lifting_operator=lifting_operator, + projection_operator=projection_operator, + edge_features=3, + inner_size=32, + internal_n_layers=2, + shared_weights=shared_weights, + ) input_.x.requires_grad = True output_ = model(input_) l = torch.mean(output_) diff --git a/tests/test_problem_zoo/test_supervised_problem.py b/tests/test_problem_zoo/test_supervised_problem.py index f3ac567..06241fa 100644 --- a/tests/test_problem_zoo/test_supervised_problem.py +++ b/tests/test_problem_zoo/test_supervised_problem.py @@ -4,28 +4,31 @@ from pina.condition import InputOutputPointsCondition from pina.problem.zoo.supervised_problem import SupervisedProblem from pina.graph import RadiusGraph + def test_constructor(): - input_ = torch.rand((100,10)) - output_ = torch.rand((100,10)) + input_ = torch.rand((100, 10)) + output_ = torch.rand((100, 10)) problem = SupervisedProblem(input_=input_, output_=output_) assert isinstance(problem, AbstractProblem) assert hasattr(problem, "conditions") assert isinstance(problem.conditions, dict) - assert list(problem.conditions.keys()) == ['data'] - assert isinstance(problem.conditions['data'], InputOutputPointsCondition) + assert list(problem.conditions.keys()) == ["data"] + assert isinstance(problem.conditions["data"], InputOutputPointsCondition) + def test_constructor_graph(): - x = torch.rand((20,100,10)) - pos = torch.rand((20,100,2)) - input_ = RadiusGraph( - x=x, pos=pos, r=.2, build_edge_attr=True - ) - output_ = torch.rand((100,10)) + x = torch.rand((20, 100, 10)) + pos = torch.rand((20, 100, 2)) + input_ = [ + RadiusGraph(x=x_, pos=pos_, radius=0.2, edge_attr=True) + for x_, pos_ in zip(x, pos) + ] + output_ = torch.rand((100, 10)) problem = SupervisedProblem(input_=input_, output_=output_) assert isinstance(problem, AbstractProblem) assert hasattr(problem, "conditions") assert isinstance(problem.conditions, dict) - assert list(problem.conditions.keys()) == ['data'] - assert isinstance(problem.conditions['data'], InputOutputPointsCondition) - assert isinstance(problem.conditions['data'].input_points, list) - assert isinstance(problem.conditions['data'].output_points, torch.Tensor) + assert list(problem.conditions.keys()) == ["data"] + assert isinstance(problem.conditions["data"], InputOutputPointsCondition) + assert isinstance(problem.conditions["data"].input_points, list) + assert isinstance(problem.conditions["data"].output_points, torch.Tensor)