Add Graph class and tests for Graph and Collector + Dataloader refactoring
committed by Nicola Demo
parent 4fdb5641d4
commit e63c3d9061
pina/graph.py | 332

@@ -1,118 +1,240 @@
Removed (previous implementation of pina/graph.py):

""" Module for Loss class """
import logging
from torch_geometric.nn import MessagePassing, InstanceNorm, radius_graph
from torch_geometric.data import Data
import torch


class Graph:
    """
    PINA Graph managing the PyG Data class.
    """
    def __init__(self, data):
        self.data = data

    @staticmethod
    def _build_triangulation(**kwargs):
        logging.debug("Creating graph with triangulation mode.")

        # check for mandatory arguments
        if "nodes_coordinates" not in kwargs:
            raise ValueError("Nodes coordinates must be provided in the kwargs.")
        if "nodes_data" not in kwargs:
            raise ValueError("Nodes data must be provided in the kwargs.")
        if "triangles" not in kwargs:
            raise ValueError("Triangles must be provided in the kwargs.")

        nodes_coordinates = kwargs["nodes_coordinates"]
        nodes_data = kwargs["nodes_data"]
        triangles = kwargs["triangles"]

        def less_first(a, b):
            return [a, b] if a < b else [b, a]

        list_of_edges = []

        for triangle in triangles:
            for e1, e2 in [[0, 1], [1, 2], [2, 0]]:
                list_of_edges.append(less_first(triangle[e1], triangle[e2]))

        array_of_edges = torch.unique(torch.Tensor(list_of_edges), dim=0)  # remove duplicates
        array_of_edges = array_of_edges.t().contiguous()
        print(array_of_edges)

        # list_of_lengths = []
        # for p1, p2 in array_of_edges:
        #     x1, y1 = tri.points[p1]
        #     x2, y2 = tri.points[p2]
        #     list_of_lengths.append((x1-x2)**2 + (y1-y2)**2)
        # array_of_lengths = np.sqrt(np.array(list_of_lengths))
        # return array_of_edges, array_of_lengths

        return Data(
            x=nodes_data,
            pos=nodes_coordinates.T,
            edge_index=array_of_edges,
        )

    @staticmethod
    def _build_radius(**kwargs):
        logging.debug("Creating graph with radius mode.")

        # check for mandatory arguments
        if "nodes_coordinates" not in kwargs:
            raise ValueError("Nodes coordinates must be provided in the kwargs.")
        if "nodes_data" not in kwargs:
            raise ValueError("Nodes data must be provided in the kwargs.")
        if "radius" not in kwargs:
            raise ValueError("Radius must be provided in the kwargs.")

        nodes_coordinates = kwargs["nodes_coordinates"]
        nodes_data = kwargs["nodes_data"]
        radius = kwargs["radius"]

        edges_data = kwargs.get("edge_data", None)
        loop = kwargs.get("loop", False)
        batch = kwargs.get("batch", None)

        logging.debug(f"radius: {radius}, loop: {loop}, "
                      f"batch: {batch}")

        edge_index = radius_graph(
            x=nodes_coordinates.tensor,
            r=radius,
            loop=loop,
            batch=batch,
        )

        logging.debug(f"edge_index computed")
        return Data(
            x=nodes_data.tensor,
            pos=nodes_coordinates.tensor,
            edge_index=edge_index,
            edge_attr=edges_data,
        )

    @staticmethod
    def build(mode, **kwargs):
        """
        Constructor for the `Graph` class.
        """
        if mode == "radius":
            graph = Graph._build_radius(**kwargs)
        elif mode == "triangulation":
            graph = Graph._build_triangulation(**kwargs)
        else:
            raise ValueError(f"Mode {mode} not recognized")
        return Graph(graph)

    def __repr__(self):
        return f"Graph(data={self.data})"
Added (refactored implementation of pina/graph.py):

from logging import warning

import torch

from . import LabelTensor
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected


class Graph:
    """
    Class for the graph construction.
    """

    def __init__(self,
                 x,
                 pos,
                 edge_index,
                 edge_attr=None,
                 build_edge_attr=False,
                 undirected=False,
                 additional_params=None):
        """
        Constructor for the Graph class.

        :param x: The node features.
        :type x: torch.Tensor or list[torch.Tensor]
        :param pos: The node positions.
        :type pos: torch.Tensor or list[torch.Tensor]
        :param edge_index: The edge index.
        :type edge_index: torch.Tensor or list[torch.Tensor]
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor or list[torch.Tensor]
        :param build_edge_attr: Whether to build the edge attributes.
        :type build_edge_attr: bool
        :param undirected: Whether to build an undirected graph.
        :type undirected: bool
        :param additional_params: Additional parameters.
        :type additional_params: dict
        """
        self.data = []
        x, pos, edge_index = Graph._check_input_consistency(x, pos, edge_index)

        # Check input dimension consistency and store the number of graphs
        data_len = self._check_len_consistency(x, pos)

        # Initialize additional_params (if present)
        if additional_params is not None:
            if not isinstance(additional_params, dict):
                raise TypeError("additional_params must be a dictionary.")
            for param, val in additional_params.items():
                # Check if the values are tensors or lists of tensors
                if isinstance(val, torch.Tensor):
                    # If the tensor is 3D, we split it into a list of 2D
                    # tensors: one additional parameter for each graph
                    if val.ndim == 3:
                        additional_params[param] = [val[i] for i in
                                                    range(val.shape[0])]
                    # If the tensor is 2D, we replicate it for each graph
                    elif val.ndim == 2:
                        additional_params[param] = [val] * data_len
                    # If the tensor is 1D, each graph gets a scalar value as
                    # additional parameter
                    if val.ndim == 1:
                        if len(val) == data_len:
                            additional_params[param] = [val[i] for i in
                                                        range(len(val))]
                        else:
                            additional_params[param] = [val for _ in
                                                        range(data_len)]
                elif not isinstance(val, list):
                    raise TypeError("additional_params values must be tensors "
                                    "or lists of tensors.")
        else:
            additional_params = {}

        # Make the graphs undirected
        if undirected:
            if isinstance(edge_index, list):
                edge_index = [to_undirected(e) for e in edge_index]
            else:
                edge_index = to_undirected(edge_index)

        if build_edge_attr:
            if edge_attr is not None:
                warning("Edge attributes are provided, build_edge_attr is set "
                        "to True. The provided edge attributes will be ignored.")
            edge_attr = self._build_edge_attr(pos, edge_index)

        # Prepare internal lists to create a graph list (same positions but
        # different node features)
        if isinstance(x, list) and isinstance(pos, (torch.Tensor, LabelTensor)):
            # Replicate the positions, edge_index and edge_attr
            pos, edge_index = [pos] * data_len, [edge_index] * data_len
            if edge_attr is not None:
                edge_attr = [edge_attr] * data_len
        # Prepare internal lists to create a list containing a single graph
        elif (isinstance(x, (torch.Tensor, LabelTensor))
              and isinstance(pos, (torch.Tensor, LabelTensor))):
            # Encapsulate the input tensors into lists
            x, pos, edge_index = [x], [pos], [edge_index]
            if isinstance(edge_attr, torch.Tensor):
                edge_attr = [edge_attr]
        # Prepare internal lists to create a list of graphs (same node features
        # but different positions)
        elif (isinstance(x, (torch.Tensor, LabelTensor))
              and isinstance(pos, list)):
            # Replicate the node features
            x = [x] * data_len
        elif not isinstance(x, list) and not isinstance(pos, list):
            raise TypeError("x and pos must be lists or tensors.")

        # Perform the graph construction
        self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)

    def _build_graph_list(self, x, pos, edge_index, edge_attr,
                          additional_params):
        for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
            if isinstance(x_, LabelTensor):
                x_ = x_.tensor
            add_params_local = {k: v[i] for k, v in additional_params.items()}
            if edge_attr is not None:
                self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
                                      edge_attr=edge_attr[i],
                                      **add_params_local))
            else:
                self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
                                      **add_params_local))

    @staticmethod
    def _build_edge_attr(pos, edge_index):
        # Squared componentwise distance between the two endpoints of each edge
        if isinstance(pos, torch.Tensor):
            pos = [pos]
            edge_index = [edge_index]
        distance = [(pos_[edge_index_[0]] - pos_[edge_index_[1]]) ** 2 for
                    pos_, edge_index_ in zip(pos, edge_index)]
        return distance

    @staticmethod
    def _check_len_consistency(x, pos):
        if isinstance(x, list) and isinstance(pos, list):
            if len(x) != len(pos):
                raise ValueError("x and pos must have the same length.")
            return max(len(x), len(pos))
        elif isinstance(x, list) and not isinstance(pos, list):
            return len(x)
        elif not isinstance(x, list) and isinstance(pos, list):
            return len(pos)
        else:
            return 1

    @staticmethod
    def _check_input_consistency(x, pos, edge_index=None):
        # If x is a 3D tensor, we split it into a list of 2D tensors
        if isinstance(x, torch.Tensor) and x.ndim == 3:
            x = [x[i] for i in range(x.shape[0])]

        # If pos is a 3D tensor, we split it into a list of 2D tensors
        if isinstance(pos, torch.Tensor) and pos.ndim == 3:
            pos = [pos[i] for i in range(pos.shape[0])]

        # If edge_index is a 3D tensor, we split it into a list of 2D tensors
        if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
            edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
        return x, pos, edge_index
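A usage sketch of the refactored Graph class (illustrative, not part of the committed file; it assumes this branch of PINA is importable and the shapes are arbitrary): three graphs sharing the same positions and connectivity, each with its own node features and a per-graph scalar parameter y.

import torch
from pina.graph import Graph

pos = torch.rand(10, 3)                               # 10 nodes in 3D, shared by all graphs
x = [torch.rand(10, 2) for _ in range(3)]             # per-graph node features -> 3 graphs
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])     # precomputed connectivity, shape (2, num_edges)

graph = Graph(x=x, pos=pos, edge_index=edge_index,
              undirected=True,
              additional_params={'y': torch.ones(3)}) # one scalar value per graph

assert len(graph.data) == 3                           # a list of torch_geometric Data objects
print(graph.data[0])                                  # a Data object carrying x, pos, edge_index and y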
class RadiusGraph(Graph):
    def __init__(self,
                 x,
                 pos,
                 r,
                 build_edge_attr=False,
                 undirected=False,
                 additional_params=None):
        x, pos, edge_index = Graph._check_input_consistency(x, pos)

        if isinstance(pos, (torch.Tensor, LabelTensor)):
            edge_index = RadiusGraph._radius_graph(pos, r)
        else:
            edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]

        super().__init__(x=x, pos=pos, edge_index=edge_index,
                         build_edge_attr=build_edge_attr,
                         undirected=undirected,
                         additional_params=additional_params)

    @staticmethod
    def _radius_graph(points, r):
        """
        Implementation of the radius graph construction.

        :param points: The input points.
        :type points: torch.Tensor
        :param r: The radius.
        :type r: float
        :return: The edge index.
        :rtype: torch.Tensor
        """
        dist = torch.cdist(points, points, p=2)
        edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
        return edge_index
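A short RadiusGraph sketch (again illustrative, not part of the file): a 3D position tensor is split into one graph per leading index, and nodes closer than r are connected. Note that, as written, torch.nonzero(dist <= r) also pairs each node with itself, so edge_index contains self-loops.

import torch
from pina.graph import RadiusGraph

pos = torch.rand(3, 10, 3)   # 3 graphs, 10 nodes each, 3D coordinates
x = torch.rand(3, 10, 2)     # matching node features
graph = RadiusGraph(x=x, pos=pos, r=0.4, build_edge_attr=True)

for d in graph.data:
    assert d.edge_index.shape[0] == 2                     # shape (2, num_edges)
    assert d.edge_attr.shape[0] == d.edge_index.shape[1]  # one attribute row per edge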
class KNNGraph(Graph):
    def __init__(self,
                 x,
                 pos,
                 k,
                 build_edge_attr=False,
                 undirected=False,
                 additional_params=None):
        x, pos, edge_index = Graph._check_input_consistency(x, pos)
        if isinstance(pos, (torch.Tensor, LabelTensor)):
            edge_index = KNNGraph._knn_graph(pos, k)
        else:
            edge_index = [KNNGraph._knn_graph(p, k) for p in pos]
        super().__init__(x=x, pos=pos, edge_index=edge_index,
                         build_edge_attr=build_edge_attr,
                         undirected=undirected,
                         additional_params=additional_params)

    @staticmethod
    def _knn_graph(points, k):
        """
        Implementation of the k-nearest neighbors graph construction.

        :param points: The input points.
        :type points: torch.Tensor
        :param k: The number of nearest neighbors.
        :type k: int
        :return: The edge index.
        :rtype: torch.Tensor
        """
        dist = torch.cdist(points, points, p=2)
        knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:]
        row = torch.arange(points.size(0)).repeat_interleave(k)
        col = knn_indices.flatten()
        edge_index = torch.stack([row, col], dim=0)
        return edge_index
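KNNGraph follows the same pattern but connects each node to its k nearest neighbours (the closest entry, normally the node itself, is dropped). A minimal sketch, not part of the file:

import torch
from pina.graph import KNNGraph

pos = torch.rand(10, 3)      # a single point cloud with 10 nodes
x = torch.rand(10, 2)
graph = KNNGraph(x=x, pos=pos, k=3)

data = graph.data[0]                          # single input -> a list with one Data object
assert data.edge_index.shape == (2, 10 * 3)   # k outgoing edges per node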
tests/test_collector.py | 125 (new file)

@@ -0,0 +1,125 @@
import torch
import pytest
from pina import Condition, LabelTensor, Graph
from pina.condition import InputOutputPointsCondition, DomainEquationCondition
from pina.graph import RadiusGraph
from pina.problem import AbstractProblem, SpatialProblem
from pina.domain import CartesianDomain
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
from pina.operators import laplacian


def test_supervised_tensor_collector():
    class SupervisedProblem(AbstractProblem):
        output_variables = None
        conditions = {
            'data1': Condition(input_points=torch.rand((10, 2)),
                               output_points=torch.rand((10, 2))),
            'data2': Condition(input_points=torch.rand((20, 2)),
                               output_points=torch.rand((20, 2))),
            'data3': Condition(input_points=torch.rand((30, 2)),
                               output_points=torch.rand((30, 2))),
        }

    problem = SupervisedProblem()
    collector = problem.collector
    for v in collector.conditions_name.values():
        assert v in problem.conditions.keys()
    assert all(collector._is_conditions_ready.values())


def test_pinn_collector():
    def laplace_equation(input_, output_):
        force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
                      torch.sin(input_.extract(['y']) * torch.pi))
        delta_u = laplacian(output_.extract(['u']), input_)
        return delta_u - force_term

    my_laplace = Equation(laplace_equation)
    in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y'])
    out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u'])

    class Poisson(SpatialProblem):
        output_variables = ['u']
        spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})

        conditions = {
            'gamma1':
                Condition(domain=CartesianDomain({
                    'x': [0, 1],
                    'y': 1
                }),
                          equation=FixedValue(0.0)),
            'gamma2':
                Condition(domain=CartesianDomain({
                    'x': [0, 1],
                    'y': 0
                }),
                          equation=FixedValue(0.0)),
            'gamma3':
                Condition(domain=CartesianDomain({
                    'x': 1,
                    'y': [0, 1]
                }),
                          equation=FixedValue(0.0)),
            'gamma4':
                Condition(domain=CartesianDomain({
                    'x': 0,
                    'y': [0, 1]
                }),
                          equation=FixedValue(0.0)),
            'D':
                Condition(domain=CartesianDomain({
                    'x': [0, 1],
                    'y': [0, 1]
                }),
                          equation=my_laplace),
            'data':
                Condition(input_points=in_, output_points=out_)
        }

        def poisson_sol(self, pts):
            return -(torch.sin(pts.extract(['x']) * torch.pi) *
                     torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2)

        truth_solution = poisson_sol

    problem = Poisson()
    collector = problem.collector
    for k, v in problem.conditions.items():
        if isinstance(v, InputOutputPointsCondition):
            assert collector._is_conditions_ready[k] == True
            assert list(collector.data_collections[k].keys()) == ['input_points', 'output_points']
        else:
            assert collector._is_conditions_ready[k] == False
            assert collector.data_collections[k] == {}

    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    problem.discretise_domain(10, 'grid', locations=boundaries)
    problem.discretise_domain(10, 'grid', locations='D')
    assert all(collector._is_conditions_ready.values())
    for k, v in problem.conditions.items():
        if isinstance(v, DomainEquationCondition):
            assert list(collector.data_collections[k].keys()) == ['input_points', 'equation']


def test_supervised_graph_collector():
    pos = torch.rand((100, 3))
    x = [torch.rand((100, 3)) for _ in range(10)]
    graph_list_1 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4)
    out_1 = torch.rand((10, 100, 3))
    pos = torch.rand((50, 3))
    x = [torch.rand((50, 3)) for _ in range(10)]
    graph_list_2 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4)
    out_2 = torch.rand((10, 50, 3))

    class SupervisedProblem(AbstractProblem):
        output_variables = None
        conditions = {
            'data1': Condition(input_points=graph_list_1,
                               output_points=out_1),
            'data2': Condition(input_points=graph_list_2,
                               output_points=out_2),
        }

    problem = SupervisedProblem()
    collector = problem.collector
    assert all(collector._is_conditions_ready.values())
    for v in collector.conditions_name.values():
        assert v in problem.conditions.keys()
tests/test_graph.py | 143 (new file)

@@ -0,0 +1,143 @@
import pytest
import torch
from pina import Graph
from pina.graph import RadiusGraph, KNNGraph


@pytest.mark.parametrize(
    "x, pos",
    [
        ([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]),
        ([torch.rand(10, 2) for _ in range(3)], [torch.rand(10, 3) for _ in range(3)]),
        (torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
        (torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
    ]
)
def test_build_multiple_graph_multiple_val(x, pos):
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)

    graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)


def test_build_single_graph_multiple_val():
    x = torch.rand(10, 2)
    pos = torch.rand(10, 3)
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
    assert len(graph.data) == 1
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
    assert all(len(d.edge_index) == 2 for d in data)
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
    data = graph.data
    assert len(graph.data) == 1
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)

    x = torch.rand(10, 2)
    pos = torch.rand(10, 3)
    graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
    assert len(graph.data) == 1
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
    assert all(len(d.edge_index) == 2 for d in data)
    graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
    data = graph.data
    assert len(graph.data) == 1
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)


@pytest.mark.parametrize(
    "pos",
    [
        ([torch.rand(10, 3) for _ in range(3)]),
        ([torch.rand(10, 3) for _ in range(3)]),
        (torch.rand(3, 10, 3)),
        (torch.rand(3, 10, 3))
    ]
)
def test_build_single_graph_single_val(pos):
    x = torch.rand(10, 2)
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3)
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)
    x = torch.rand(10, 2)
    graph = KNNGraph(x=x, pos=pos, build_edge_attr=False, k=3)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    graph = KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3)
    data = graph.data
    assert all(torch.isclose(d.x, x).all() for d in data)
    assert all(torch.isclose(d_.pos, pos_).all() for d_, pos_ in zip(data, pos))
    assert all(len(d.edge_index) == 2 for d in data)
    assert all(d.edge_attr is not None for d in data)
    assert all(d.edge_index.shape[1] == d.edge_attr.shape[0] for d in data)


def test_additional_parameters_1():
    x = torch.rand(3, 10, 2)
    pos = torch.rand(3, 10, 2)
    additional_parameters = {'y': torch.ones(3)}
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
                        additional_params=additional_parameters)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(hasattr(d, 'y') for d in data)
    assert all(d_.y == 1 for d_ in data)


@pytest.mark.parametrize(
    "additional_parameters",
    [
        ({'y': torch.rand(3, 10, 1)}),
        ({'y': [torch.rand(10, 1) for _ in range(3)]}),
    ]
)
def test_additional_parameters_2(additional_parameters):
    x = torch.rand(3, 10, 2)
    pos = torch.rand(3, 10, 2)
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
                        additional_params=additional_parameters)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(hasattr(d, 'y') for d in data)
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))