Add Graph class and tests for Graph and Collector + Dataloader refactoring

This commit is contained in:
FilippoOlivo
2025-01-16 17:10:38 +01:00
committed by Nicola Demo
parent 4fdb5641d4
commit e63c3d9061
3 changed files with 496 additions and 106 deletions

View File

@@ -1,118 +1,240 @@
""" Module for the Graph classes. """
from logging import warning
import logging
from torch_geometric.nn import MessagePassing, InstanceNorm, radius_graph
from torch_geometric.data import Data
import torch
from . import LabelTensor
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
class Graph:
    """
    PINA Graph managing the PyG Data class.

    The constructor normalises its inputs to one entry per graph and builds
    one ``torch_geometric.data.Data`` object for each of them. The resulting
    objects are stored, in order, in the ``data`` list.
    """

    def __init__(self,
                 x,
                 pos,
                 edge_index,
                 edge_attr=None,
                 build_edge_attr=False,
                 undirected=False,
                 additional_params=None):
        """
        Constructor for the Graph class.

        :param x: The node features.
        :type x: torch.Tensor or list[torch.Tensor]
        :param pos: The node positions.
        :type pos: torch.Tensor or list[torch.Tensor]
        :param edge_index: The edge index.
        :type edge_index: torch.Tensor or list[torch.Tensor]
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor or list[torch.Tensor]
        :param build_edge_attr: Whether to build the edge attributes.
            When True, any provided ``edge_attr`` is ignored (a warning is
            logged) and the attributes are computed from ``pos``.
        :type build_edge_attr: bool
        :param undirected: Whether to build an undirected graph.
        :type undirected: bool
        :param additional_params: Additional parameters, forwarded as extra
            keyword arguments to each ``Data`` object. The dictionary passed
            by the caller is not modified.
        :type additional_params: dict
        :raises TypeError: If ``additional_params`` is not a dictionary, if
            one of its values is neither a tensor nor a list, or if neither
            ``x`` nor ``pos`` is a list or a tensor.
        :raises ValueError: If ``x`` and ``pos`` are lists of different
            lengths.
        """
        self.data = []
        x, pos, edge_index = Graph._check_input_consistency(x, pos, edge_index)

        # Check input dimension consistency and store the number of graphs
        data_len = self._check_len_consistency(x, pos)

        # Normalise the additional parameters to one entry per graph
        additional_params = self._prepare_additional_params(
            additional_params, data_len)

        # Make the graphs undirected
        if undirected:
            if isinstance(edge_index, list):
                edge_index = [to_undirected(e) for e in edge_index]
            else:
                edge_index = to_undirected(edge_index)

        if build_edge_attr:
            if edge_attr is not None:
                warning("Edge attributes are provided, build_edge_attr is set "
                        "to True. The provided edge attributes will be ignored.")
            edge_attr = self._build_edge_attr(pos, edge_index)

        tensor_types = (torch.Tensor, LabelTensor)
        # Prepare internal lists to create a graph list (same positions but
        # different node features)
        if isinstance(x, list) and isinstance(pos, tensor_types):
            # Replicate the positions, edge_index and edge_attr
            pos, edge_index = [pos] * data_len, [edge_index] * data_len
            if edge_attr is not None:
                edge_attr = [edge_attr] * data_len
        # Prepare internal lists to create a list containing a single graph
        elif isinstance(x, tensor_types) and isinstance(pos, tensor_types):
            # Encapsulate the input tensors into lists
            x, pos, edge_index = [x], [pos], [edge_index]
            if isinstance(edge_attr, torch.Tensor):
                edge_attr = [edge_attr]
        # Prepare internal lists to create a list of graphs (same node
        # features but different positions)
        elif isinstance(x, tensor_types) and isinstance(pos, list):
            # Replicate the node features
            x = [x] * data_len
        elif not isinstance(x, list) and not isinstance(pos, list):
            raise TypeError("x and pos must be lists or tensors.")

        # A single edge_index shared by several position sets must be
        # replicated; zipping over the tensor itself would iterate its rows.
        if isinstance(edge_index, torch.Tensor):
            edge_index = [edge_index] * data_len

        # Perform the graph construction
        self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)

    @staticmethod
    def _prepare_additional_params(additional_params, data_len):
        """
        Build a new dictionary holding, for each additional parameter, a
        per-graph sequence of values.

        :param additional_params: The user-provided parameters, or None.
        :type additional_params: dict or None
        :param data_len: The number of graphs.
        :type data_len: int
        :return: A new dictionary; the input dictionary is left untouched.
        :rtype: dict
        :raises TypeError: If ``additional_params`` is not a dictionary or
            one of its values is neither a tensor nor a list.
        """
        if additional_params is None:
            return {}
        if not isinstance(additional_params, dict):
            raise TypeError("additional_params must be a dictionary.")

        prepared = {}
        for param, val in additional_params.items():
            if isinstance(val, torch.Tensor):
                if val.ndim == 3:
                    # 3D tensor: one 2D slice per graph
                    prepared[param] = [val[i] for i in range(val.shape[0])]
                elif val.ndim == 2:
                    # 2D tensor: the same value is shared by every graph
                    prepared[param] = [val] * data_len
                elif val.ndim == 1 and len(val) == data_len:
                    # 1D tensor with one scalar per graph
                    prepared[param] = [val[i] for i in range(len(val))]
                elif val.ndim == 1:
                    # 1D tensor of another length: shared by every graph
                    prepared[param] = [val] * data_len
                else:
                    # Other ranks are passed through and indexed per graph
                    prepared[param] = val
            elif isinstance(val, list):
                prepared[param] = val
            else:
                raise TypeError("additional_params values must be tensors "
                                "or lists of tensors.")
        return prepared

    def _build_graph_list(self, x, pos, edge_index, edge_attr,
                          additional_params):
        """
        Create one ``Data`` object per graph and append it to ``self.data``.

        :param x: The node features, one entry per graph.
        :type x: list
        :param pos: The node positions, one entry per graph.
        :type pos: list
        :param edge_index: The edge indices, one entry per graph.
        :type edge_index: list
        :param edge_attr: The edge attributes, one entry per graph, or None.
        :type edge_attr: list or None
        :param additional_params: Per-graph sequences of extra ``Data``
            keyword arguments.
        :type additional_params: dict
        """
        for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
            # Node features given as LabelTensor are stored as plain tensors
            if isinstance(x_, LabelTensor):
                x_ = x_.tensor
            data_kwargs = {k: v[i] for k, v in additional_params.items()}
            # edge_attr is only forwarded when it is available
            if edge_attr is not None:
                data_kwargs["edge_attr"] = edge_attr[i]
            self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
                                  **data_kwargs))

    @staticmethod
    def _build_edge_attr(pos, edge_index):
        """
        Compute the edge attributes as the component-wise squared difference
        between the positions of the two nodes of each edge.

        :param pos: The node positions of one graph, or of several graphs.
        :type pos: torch.Tensor or list[torch.Tensor]
        :param edge_index: The edge index of one graph, or of several graphs.
            A single edge index given with a list of positions is applied to
            every position set.
        :type edge_index: torch.Tensor or list[torch.Tensor]
        :return: A tensor when ``pos`` is a tensor, otherwise a list with one
            tensor per position set.
        :rtype: torch.Tensor or list[torch.Tensor]
        """
        if isinstance(pos, torch.Tensor):
            # The difference must be parenthesised: ``a - b ** 2`` would
            # square only the second operand.
            return (pos[edge_index[0]] - pos[edge_index[1]]) ** 2
        if isinstance(edge_index, torch.Tensor):
            edge_index = [edge_index] * len(pos)
        return [(pos_[edge_index_[0]] - pos_[edge_index_[1]]) ** 2
                for pos_, edge_index_ in zip(pos, edge_index)]

    @staticmethod
    def _check_len_consistency(x, pos):
        """
        Return the number of graphs described by ``x`` and ``pos``.

        :param x: The node features.
        :type x: torch.Tensor or list[torch.Tensor]
        :param pos: The node positions.
        :type pos: torch.Tensor or list[torch.Tensor]
        :return: The length of whichever argument is a list, or 1 when
            neither is a list.
        :rtype: int
        :raises ValueError: If both are lists of different lengths.
        """
        x_is_list = isinstance(x, list)
        pos_is_list = isinstance(pos, list)
        if x_is_list and pos_is_list:
            if len(x) != len(pos):
                raise ValueError("x and pos must have the same length.")
            return len(x)
        if x_is_list:
            return len(x)
        if pos_is_list:
            return len(pos)
        return 1

    @staticmethod
    def _check_input_consistency(x, pos, edge_index=None):
        """
        Split every 3D tensor argument into a list of 2D tensors along its
        first dimension; other arguments are returned unchanged.

        :param x: The node features.
        :param pos: The node positions.
        :param edge_index: The edge index, if any.
        :return: The tuple ``(x, pos, edge_index)`` after the split.
        :rtype: tuple
        """
        def split_if_3d(value):
            if isinstance(value, torch.Tensor) and value.ndim == 3:
                return [value[i] for i in range(value.shape[0])]
            return value

        return split_if_3d(x), split_if_3d(pos), split_if_3d(edge_index)

    def __repr__(self):
        """Return a string showing the stored list of ``Data`` objects."""
        return f"Graph(data={self.data})"
class RadiusGraph(Graph):
    """
    Graph whose edges join the pairs of points whose Euclidean distance does
    not exceed a given radius.
    """

    def __init__(self,
                 x,
                 pos,
                 r,
                 build_edge_attr=False,
                 undirected=False,
                 additional_params=None):
        """
        Build the radius-based connectivity and delegate to ``Graph``.

        :param x: The node features.
        :type x: torch.Tensor or list[torch.Tensor]
        :param pos: The node positions.
        :type pos: torch.Tensor or list[torch.Tensor]
        :param r: The radius.
        :type r: float
        :param build_edge_attr: Whether to build the edge attributes.
        :type build_edge_attr: bool
        :param undirected: Whether to build an undirected graph.
        :type undirected: bool
        :param additional_params: Additional parameters.
        :type additional_params: dict
        """
        # Only x and pos are normalised here; the edge index is computed below
        x, pos, _ = Graph._check_input_consistency(x, pos)
        single_point_set = isinstance(pos, (torch.Tensor, LabelTensor))
        if single_point_set:
            connectivity = RadiusGraph._radius_graph(pos, r)
        else:
            connectivity = [RadiusGraph._radius_graph(points, r)
                            for points in pos]
        super().__init__(x=x,
                         pos=pos,
                         edge_index=connectivity,
                         build_edge_attr=build_edge_attr,
                         undirected=undirected,
                         additional_params=additional_params)

    @staticmethod
    def _radius_graph(points, r):
        """
        Implementation of the radius graph construction.

        Every ordered pair of points whose distance is at most ``r`` becomes
        an edge; no filtering of a point paired with itself is applied.

        :param points: The input points.
        :type points: torch.Tensor
        :param r: The radius.
        :type r: float
        :return: The edge index.
        :rtype: torch.Tensor
        """
        pairwise_dist = torch.cdist(points, points, p=2)
        within_radius = pairwise_dist <= r
        return torch.nonzero(within_radius, as_tuple=False).t()
class KNNGraph(Graph):
    """
    Graph whose edges join every point to its ``k`` nearest neighbours
    according to the Euclidean distance.
    """

    def __init__(self,
                 x,
                 pos,
                 k,
                 build_edge_attr=False,
                 undirected=False,
                 additional_params=None,
                 ):
        """
        Build the k-nearest-neighbours connectivity and delegate to
        ``Graph``.

        :param x: The node features.
        :type x: torch.Tensor or list[torch.Tensor]
        :param pos: The node positions.
        :type pos: torch.Tensor or list[torch.Tensor]
        :param k: The number of nearest neighbors.
        :type k: int
        :param build_edge_attr: Whether to build the edge attributes.
        :type build_edge_attr: bool
        :param undirected: Whether to build an undirected graph.
        :type undirected: bool
        :param additional_params: Additional parameters.
        :type additional_params: dict
        """
        # Only x and pos are normalised here; the edge index is computed below
        x, pos, _ = Graph._check_input_consistency(x, pos)
        if isinstance(pos, (torch.Tensor, LabelTensor)):
            edge_index = KNNGraph._knn_graph(pos, k)
        else:
            edge_index = [KNNGraph._knn_graph(p, k) for p in pos]
        super().__init__(x=x, pos=pos, edge_index=edge_index,
                         build_edge_attr=build_edge_attr,
                         undirected=undirected,
                         additional_params=additional_params)

    @staticmethod
    def _knn_graph(points, k):
        """
        Implementation of the k-nearest neighbors graph construction.

        :param points: The input points.
        :type points: torch.Tensor
        :param k: The number of nearest neighbors.
        :type k: int
        :return: The edge index.
        :rtype: torch.Tensor
        """
        dist = torch.cdist(points, points, p=2)
        # k + 1 smallest distances are requested and the first column is
        # dropped: presumably it is the point itself (distance zero).
        # NOTE(review): coincident points could take that first slot instead.
        knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:]
        # Source indices must live on the same device as the neighbour
        # indices, otherwise torch.stack fails for non-CPU inputs.
        row = torch.arange(points.size(0),
                           device=knn_indices.device).repeat_interleave(k)
        col = knn_indices.flatten()
        edge_index = torch.stack([row, col], dim=0)
        return edge_index

125
tests/test_collector.py Normal file
View File

@@ -0,0 +1,125 @@
import torch
import pytest
from pina import Condition, LabelTensor, Graph
from pina.condition import InputOutputPointsCondition, DomainEquationCondition
from pina.graph import RadiusGraph
from pina.problem import AbstractProblem, SpatialProblem
from pina.domain import CartesianDomain
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
from pina.operators import laplacian
def test_supervised_tensor_collector():
    """
    A problem made only of tensor input/output conditions must have every
    condition registered and ready in its collector.
    """
    n_points = {'data1': 10, 'data2': 20, 'data3': 30}
    supervised_conditions = {
        name: Condition(input_points=torch.rand((n, 2)),
                        output_points=torch.rand((n, 2)))
        for name, n in n_points.items()
    }

    class SupervisedProblem(AbstractProblem):
        output_variables = None
        conditions = supervised_conditions

    problem = SupervisedProblem()
    collector = problem.collector
    # Every name known to the collector must be a condition of the problem
    for condition_name in collector.conditions_name.values():
        assert condition_name in problem.conditions.keys()
    assert all(collector._is_conditions_ready.values())
def test_pinn_collector():
    """
    Domain/equation conditions become ready only after the domain is
    discretised, while input/output conditions are ready from the start.
    """
    def laplace_equation(input_, output_):
        sin_x = torch.sin(input_.extract(['x']) * torch.pi)
        sin_y = torch.sin(input_.extract(['y']) * torch.pi)
        force_term = sin_x * sin_y
        delta_u = laplacian(output_.extract(['u']), input_)
        return delta_u - force_term

    my_laplace = Equation(laplace_equation)
    in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y'])
    out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u'])

    # The four sides of the unit square, each with a fixed zero value
    boundary_ranges = {
        'gamma1': {'x': [0, 1], 'y': 1},
        'gamma2': {'x': [0, 1], 'y': 0},
        'gamma3': {'x': 1, 'y': [0, 1]},
        'gamma4': {'x': 0, 'y': [0, 1]},
    }
    poisson_conditions = {
        name: Condition(domain=CartesianDomain(ranges),
                        equation=FixedValue(0.0))
        for name, ranges in boundary_ranges.items()
    }
    poisson_conditions['D'] = Condition(
        domain=CartesianDomain({'x': [0, 1], 'y': [0, 1]}),
        equation=my_laplace)
    poisson_conditions['data'] = Condition(input_points=in_,
                                           output_points=out_)

    class Poisson(SpatialProblem):
        output_variables = ['u']
        spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
        conditions = poisson_conditions

        def poisson_sol(self, pts):
            return -(torch.sin(pts.extract(['x']) * torch.pi) *
                     torch.sin(pts.extract(['y']) * torch.pi)) / (
                         2 * torch.pi**2)

        truth_solution = poisson_sol

    problem = Poisson()
    collector = problem.collector

    # Before discretisation only the input/output condition holds data
    for name, condition in problem.conditions.items():
        if isinstance(condition, InputOutputPointsCondition):
            assert collector._is_conditions_ready[name] == True
            assert list(collector.data_collections[name].keys()) == [
                'input_points', 'output_points']
        else:
            assert collector._is_conditions_ready[name] == False
            assert collector.data_collections[name] == {}

    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    problem.discretise_domain(10, 'grid', locations=boundaries)
    problem.discretise_domain(10, 'grid', locations='D')
    assert all(collector._is_conditions_ready.values())

    # After discretisation the domain conditions hold points and equation
    for name, condition in problem.conditions.items():
        if isinstance(condition, DomainEquationCondition):
            assert list(collector.data_collections[name].keys()) == [
                'input_points', 'equation']
def test_supervised_graph_collector():
    """
    A problem whose conditions take graph inputs must have every condition
    registered and ready in its collector.
    """
    def make_graphs_and_target(n_nodes):
        # Same creation order as a hand-written sequence: positions,
        # features, graph, target
        positions = torch.rand((n_nodes, 3))
        features = [torch.rand((n_nodes, 3)) for _ in range(10)]
        graphs = RadiusGraph(pos=positions, x=features,
                             build_edge_attr=True, r=.4)
        target = torch.rand((10, n_nodes, 3))
        return graphs, target

    graph_list_1, out_1 = make_graphs_and_target(100)
    graph_list_2, out_2 = make_graphs_and_target(50)
    graph_conditions = {
        'data1': Condition(input_points=graph_list_1, output_points=out_1),
        'data2': Condition(input_points=graph_list_2, output_points=out_2),
    }

    class SupervisedProblem(AbstractProblem):
        output_variables = None
        conditions = graph_conditions

    problem = SupervisedProblem()
    collector = problem.collector
    assert all(collector._is_conditions_ready.values())
    for condition_name in collector.conditions_name.values():
        assert condition_name in problem.conditions.keys()

143
tests/test_graph.py Normal file
View File

@@ -0,0 +1,143 @@
import pytest
import torch
from pina import Graph
from pina.graph import RadiusGraph, KNNGraph
@pytest.mark.parametrize(
    "x, pos",
    [
        # One feature set and one position set per graph, given as lists
        ([torch.rand(10, 2) for _ in range(3)],
         [torch.rand(10, 3) for _ in range(3)]),
        # The same, given as 3D tensors
        (torch.rand(3, 10, 2), torch.rand(3, 10, 3)),
    ]
)
def test_build_multiple_graph_multiple_val(x, pos):
    """
    Several feature sets with several position sets must yield one graph
    per pair, for both the radius and the k-nearest-neighbours builders.
    """
    def check(graph, with_edge_attr):
        data = graph.data
        assert len(data) == 3
        assert all(torch.isclose(d_.x, x_).all() for d_, x_ in zip(data, x))
        assert all(torch.isclose(d_.pos, pos_).all()
                   for d_, pos_ in zip(data, pos))
        assert all(len(d.edge_index) == 2 for d in data)
        if with_edge_attr:
            assert all(d.edge_attr is not None for d in data)
            # No brackets around the comparison: a one-element list is
            # always truthy and would make this check vacuous.
            assert all(d.edge_index.shape[1] == d.edge_attr.shape[0]
                       for d in data)

    check(RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3), False)
    check(RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3), True)
    check(KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3), True)
def test_build_single_graph_multiple_val():
    """
    A single feature tensor with a single position tensor must yield exactly
    one graph, for both the radius and the k-nearest-neighbours builders.
    """
    def check(graph, x, pos, with_edge_attr):
        data = graph.data
        assert len(data) == 1
        assert all(torch.isclose(d.x, x).all() for d in data)
        assert all(torch.isclose(d_.pos, pos).all() for d_ in data)
        assert all(len(d.edge_index) == 2 for d in data)
        if with_edge_attr:
            assert all(d.edge_attr is not None for d in data)
            # No brackets around the comparison: a one-element list is
            # always truthy and would make this check vacuous.
            assert all(d.edge_index.shape[1] == d.edge_attr.shape[0]
                       for d in data)

    x = torch.rand(10, 2)
    pos = torch.rand(10, 3)
    check(RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3),
          x, pos, False)
    check(RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3),
          x, pos, True)

    x = torch.rand(10, 2)
    pos = torch.rand(10, 3)
    # The first KNN case covers graphs without edge attributes, mirroring
    # the radius cases above.
    check(KNNGraph(x=x, pos=pos, build_edge_attr=False, k=3),
          x, pos, False)
    check(KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3),
          x, pos, True)
@pytest.mark.parametrize(
    "pos",
    [
        # One position set per graph, given as a list
        ([torch.rand(10, 3) for _ in range(3)]),
        # The same, given as a 3D tensor
        (torch.rand(3, 10, 3)),
    ]
)
def test_build_single_graph_single_val(pos):
    """
    A single feature tensor shared by several position sets must yield one
    graph per position set, all carrying the same node features.
    """
    def check(graph, x, with_edge_attr):
        data = graph.data
        assert len(data) == 3
        assert all(torch.isclose(d.x, x).all() for d in data)
        assert all(torch.isclose(d_.pos, pos_).all()
                   for d_, pos_ in zip(data, pos))
        assert all(len(d.edge_index) == 2 for d in data)
        if with_edge_attr:
            assert all(d.edge_attr is not None for d in data)
            # No brackets around the comparison: a one-element list is
            # always truthy and would make this check vacuous.
            assert all(d.edge_index.shape[1] == d.edge_attr.shape[0]
                       for d in data)

    x = torch.rand(10, 2)
    check(RadiusGraph(x=x, pos=pos, build_edge_attr=False, r=.3), x, False)
    check(RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3), x, True)

    x = torch.rand(10, 2)
    check(KNNGraph(x=x, pos=pos, build_edge_attr=False, k=3), x, False)
    check(KNNGraph(x=x, pos=pos, build_edge_attr=True, k=3), x, True)
def test_additional_parameters_1():
    """
    A 1D additional parameter with one entry per graph must be attached to
    each graph as its own value.
    """
    x = torch.rand(3, 10, 2)
    pos = torch.rand(3, 10, 2)
    additional_parameters = {'y': torch.ones(3)}
    graph = RadiusGraph(x=x,
                        pos=pos,
                        build_edge_attr=True,
                        r=.3,
                        additional_params=additional_parameters)
    graphs = graph.data
    assert len(graph.data) == 3
    for built, features in zip(graphs, x):
        assert torch.isclose(built.x, features).all()
    for built in graphs:
        assert hasattr(built, 'y')
    for built in graphs:
        assert built.y == 1
@pytest.mark.parametrize(
    "additional_parameters",
    [
        ({'y': torch.rand(3, 10, 1)}),
        ({'y': [torch.rand(10, 1) for _ in range(3)]}),
    ]
)
def test_additional_parameters_2(additional_parameters):
    """
    A per-node additional parameter, given as a 3D tensor or as a list of
    2D tensors, must be attached to the matching graph.
    """
    x = torch.rand(3, 10, 2)
    pos = torch.rand(3, 10, 2)
    graph = RadiusGraph(x=x, pos=pos, build_edge_attr=True, r=.3,
                        additional_params=additional_parameters)
    assert len(graph.data) == 3
    data = graph.data
    assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
    assert all(hasattr(d, 'y') for d in data)
    # Check the values of the additional parameter, not the node features
    # a second time.
    assert all(torch.isclose(d_.y, y_).all()
               for (d_, y_) in zip(data, additional_parameters['y']))