Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
181
pina/graph.py
181
pina/graph.py
@@ -14,15 +14,15 @@ class Graph:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x,
|
||||
pos,
|
||||
edge_index,
|
||||
edge_attr=None,
|
||||
build_edge_attr=False,
|
||||
undirected=False,
|
||||
custom_build_edge_attr=None,
|
||||
additional_params=None
|
||||
self,
|
||||
x,
|
||||
pos,
|
||||
edge_index,
|
||||
edge_attr=None,
|
||||
build_edge_attr=False,
|
||||
undirected=False,
|
||||
custom_build_edge_attr=None,
|
||||
additional_params=None,
|
||||
):
|
||||
"""
|
||||
Constructor for the Graph class. This object creates a list of PyTorch Geometric Data objects.
|
||||
@@ -72,8 +72,9 @@ class Graph:
|
||||
self._build_edge_attr = custom_build_edge_attr
|
||||
|
||||
# Check consistency and initialize additional_parameters (if present)
|
||||
additional_params = self._check_additional_params(additional_params,
|
||||
data_len)
|
||||
additional_params = self._check_additional_params(
|
||||
additional_params, data_len
|
||||
)
|
||||
|
||||
# Make the graphs undirected
|
||||
if undirected:
|
||||
@@ -84,49 +85,63 @@ class Graph:
|
||||
|
||||
# Prepare internal lists to create a graph list (same positions but
|
||||
# different node features)
|
||||
if isinstance(x, list) and isinstance(pos,
|
||||
(torch.Tensor, LabelTensor)):
|
||||
if isinstance(x, list) and isinstance(pos, (torch.Tensor, LabelTensor)):
|
||||
# Replicate the positions, edge_index and edge_attr
|
||||
pos, edge_index = [pos] * data_len, [edge_index] * data_len
|
||||
# Prepare internal lists to create a list containing a single graph
|
||||
elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(pos, (
|
||||
torch.Tensor, LabelTensor)):
|
||||
elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(
|
||||
pos, (torch.Tensor, LabelTensor)
|
||||
):
|
||||
# Encapsulate the input tensors into lists
|
||||
x, pos, edge_index = [x], [pos], [edge_index]
|
||||
# Prepare internal lists to create a list of graphs (same node features
|
||||
# but different positions)
|
||||
elif (isinstance(x, (torch.Tensor, LabelTensor))
|
||||
and isinstance(pos, list)):
|
||||
elif isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(
|
||||
pos, list
|
||||
):
|
||||
# Replicate the node features
|
||||
x = [x] * data_len
|
||||
elif not isinstance(x, list) and not isinstance(pos, list):
|
||||
raise TypeError("x and pos must be lists or tensors.")
|
||||
|
||||
# Build the edge attributes
|
||||
edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
|
||||
data_len, edge_index, pos,
|
||||
x)
|
||||
edge_attr = self._check_and_build_edge_attr(
|
||||
edge_attr, build_edge_attr, data_len, edge_index, pos, x
|
||||
)
|
||||
|
||||
# Perform the graph construction
|
||||
self._build_graph_list(
|
||||
x, pos, edge_index, edge_attr, additional_params)
|
||||
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
|
||||
|
||||
def _build_graph_list(self, x, pos, edge_index, edge_attr,
|
||||
additional_params):
|
||||
def _build_graph_list(
|
||||
self, x, pos, edge_index, edge_attr, additional_params
|
||||
):
|
||||
for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
|
||||
add_params_local = {k: v[i] for k, v in additional_params.items()}
|
||||
if edge_attr is not None:
|
||||
self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
|
||||
edge_attr=edge_attr[i],
|
||||
**add_params_local))
|
||||
self.data.append(
|
||||
Data(
|
||||
x=x_,
|
||||
pos=pos_,
|
||||
edge_index=edge_index_,
|
||||
edge_attr=edge_attr[i],
|
||||
**add_params_local,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
|
||||
**add_params_local))
|
||||
self.data.append(
|
||||
Data(
|
||||
x=x_,
|
||||
pos=pos_,
|
||||
edge_index=edge_index_,
|
||||
**add_params_local,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_edge_attr(x, pos, edge_index):
|
||||
distance = torch.abs(pos[edge_index[0]] -
|
||||
pos[edge_index[1]]).as_subclass(torch.Tensor)
|
||||
distance = torch.abs(
|
||||
pos[edge_index[0]] - pos[edge_index[1]]
|
||||
).as_subclass(torch.Tensor)
|
||||
return distance
|
||||
|
||||
@staticmethod
|
||||
@@ -147,32 +162,39 @@ class Graph:
|
||||
# If x is a 3D tensor, we split it into a list of 2D tensors
|
||||
if isinstance(x, torch.Tensor) and x.ndim == 3:
|
||||
x = [x[i] for i in range(x.shape[0])]
|
||||
elif (not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and
|
||||
not (isinstance(x, torch.Tensor) and x.ndim == 2)):
|
||||
raise TypeError("x must be either a list of 2D tensors or a 2D "
|
||||
"tensor or a 3D tensor")
|
||||
elif not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and not (
|
||||
isinstance(x, torch.Tensor) and x.ndim == 2
|
||||
):
|
||||
raise TypeError(
|
||||
"x must be either a list of 2D tensors or a 2D "
|
||||
"tensor or a 3D tensor"
|
||||
)
|
||||
|
||||
# If pos is a 3D tensor, we split it into a list of 2D tensors
|
||||
if isinstance(pos, torch.Tensor) and pos.ndim == 3:
|
||||
pos = [pos[i] for i in range(pos.shape[0])]
|
||||
elif not (isinstance(pos, list) and all(
|
||||
t.ndim == 2 for t in pos)) and not (
|
||||
isinstance(pos, torch.Tensor) and pos.ndim == 2):
|
||||
raise TypeError("pos must be either a list of 2D tensors or a 2D "
|
||||
"tensor or a 3D tensor")
|
||||
elif not (
|
||||
isinstance(pos, list) and all(t.ndim == 2 for t in pos)
|
||||
) and not (isinstance(pos, torch.Tensor) and pos.ndim == 2):
|
||||
raise TypeError(
|
||||
"pos must be either a list of 2D tensors or a 2D "
|
||||
"tensor or a 3D tensor"
|
||||
)
|
||||
|
||||
# If edge_index is a 3D tensor, we split it into a list of 2D tensors
|
||||
if edge_index is not None:
|
||||
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
|
||||
edge_index = [edge_index[i]
|
||||
for i in range(edge_index.shape[0])]
|
||||
elif not (isinstance(edge_index, list) and all(
|
||||
t.ndim == 2 for t in edge_index)) and not (
|
||||
isinstance(edge_index,
|
||||
torch.Tensor) and edge_index.ndim == 2):
|
||||
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
|
||||
elif not (
|
||||
isinstance(edge_index, list)
|
||||
and all(t.ndim == 2 for t in edge_index)
|
||||
) and not (
|
||||
isinstance(edge_index, torch.Tensor) and edge_index.ndim == 2
|
||||
):
|
||||
raise TypeError(
|
||||
"edge_index must be either a list of 2D tensors or a 2D "
|
||||
"tensor or a 3D tensor")
|
||||
"tensor or a 3D tensor"
|
||||
)
|
||||
|
||||
return x, pos, edge_index
|
||||
|
||||
@@ -188,8 +210,9 @@ class Graph:
|
||||
# In this case there must be a additional parameter for each
|
||||
# node
|
||||
if val.ndim == 3:
|
||||
additional_params[param] = [val[i] for i in
|
||||
range(val.shape[0])]
|
||||
additional_params[param] = [
|
||||
val[i] for i in range(val.shape[0])
|
||||
]
|
||||
# If the tensor is 2D, we replicate it for each node
|
||||
elif val.ndim == 2:
|
||||
additional_params[param] = [val] * data_len
|
||||
@@ -197,44 +220,48 @@ class Graph:
|
||||
# additional parameter
|
||||
if val.ndim == 1:
|
||||
if len(val) == data_len:
|
||||
additional_params[param] = [val[i] for i in
|
||||
range(len(val))]
|
||||
additional_params[param] = [
|
||||
val[i] for i in range(len(val))
|
||||
]
|
||||
else:
|
||||
additional_params[param] = [val for _ in
|
||||
range(data_len)]
|
||||
additional_params[param] = [
|
||||
val for _ in range(data_len)
|
||||
]
|
||||
elif not isinstance(val, list):
|
||||
raise TypeError("additional_params values must be tensors "
|
||||
"or lists of tensors.")
|
||||
raise TypeError(
|
||||
"additional_params values must be tensors "
|
||||
"or lists of tensors."
|
||||
)
|
||||
else:
|
||||
additional_params = {}
|
||||
return additional_params
|
||||
|
||||
def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
|
||||
edge_index, pos, x):
|
||||
def _check_and_build_edge_attr(
|
||||
self, edge_attr, build_edge_attr, data_len, edge_index, pos, x
|
||||
):
|
||||
# Check if edge_attr is consistent with x and pos
|
||||
if edge_attr is not None:
|
||||
if build_edge_attr is True:
|
||||
warning("edge_attr is not None. build_edge_attr will not be "
|
||||
"considered.")
|
||||
warning(
|
||||
"edge_attr is not None. build_edge_attr will not be "
|
||||
"considered."
|
||||
)
|
||||
if isinstance(edge_attr, list):
|
||||
if len(edge_attr) != data_len:
|
||||
raise TypeError("edge_attr must have the same length as x "
|
||||
"and pos.")
|
||||
raise TypeError(
|
||||
"edge_attr must have the same length as x " "and pos."
|
||||
)
|
||||
return [edge_attr] * data_len
|
||||
|
||||
if build_edge_attr:
|
||||
return [self._build_edge_attr(x_, pos_, edge_index_) for
|
||||
x_, pos_, edge_index_ in zip(x, pos, edge_index)]
|
||||
return [
|
||||
self._build_edge_attr(x_, pos_, edge_index_)
|
||||
for x_, pos_, edge_index_ in zip(x, pos, edge_index)
|
||||
]
|
||||
|
||||
|
||||
class RadiusGraph(Graph):
|
||||
def __init__(
|
||||
self,
|
||||
x,
|
||||
pos,
|
||||
r,
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, x, pos, r, **kwargs):
|
||||
x, pos, edge_index = Graph._check_input_consistency(x, pos)
|
||||
|
||||
if isinstance(pos, (torch.Tensor, LabelTensor)):
|
||||
@@ -242,8 +269,7 @@ class RadiusGraph(Graph):
|
||||
else:
|
||||
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
|
||||
|
||||
super().__init__(x=x, pos=pos, edge_index=edge_index,
|
||||
**kwargs)
|
||||
super().__init__(x=x, pos=pos, edge_index=edge_index, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _radius_graph(points, r):
|
||||
@@ -264,20 +290,13 @@ class RadiusGraph(Graph):
|
||||
|
||||
|
||||
class KNNGraph(Graph):
|
||||
def __init__(
|
||||
self,
|
||||
x,
|
||||
pos,
|
||||
k,
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, x, pos, k, **kwargs):
|
||||
x, pos, edge_index = Graph._check_input_consistency(x, pos)
|
||||
if isinstance(pos, (torch.Tensor, LabelTensor)):
|
||||
edge_index = KNNGraph._knn_graph(pos, k)
|
||||
else:
|
||||
edge_index = [KNNGraph._knn_graph(p, k) for p in pos]
|
||||
super().__init__(x=x, pos=pos, edge_index=edge_index,
|
||||
**kwargs)
|
||||
super().__init__(x=x, pos=pos, edge_index=edge_index, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _knn_graph(points, k):
|
||||
|
||||
Reference in New Issue
Block a user