Refactoring for 0.2 * Data module, data loader and dataset * Refactor LabelTensor * Refactor solvers Co-authored-by: dario-coscia <dariocos99@gmail.com>
118 lines
3.5 KiB
Python
118 lines
3.5 KiB
Python
""" Module for Loss class """
|
|
|
|
import logging
|
|
from torch_geometric.nn import MessagePassing, InstanceNorm, radius_graph
|
|
from torch_geometric.data import Data
|
|
import torch
|
|
|
|
class Graph:
    """
    PINA Graph managing the PyG Data class.

    An instance simply wraps the object passed to the constructor in
    ``self.data``; :meth:`build` creates that object as a
    ``torch_geometric.data.Data`` from raw node information.
    """

    # Error message raised by ``_require`` for each mandatory keyword argument.
    _MISSING_ARG_MESSAGES = {
        "nodes_coordinates": "Nodes coordinates must be provided in the kwargs.",
        "nodes_data": "Nodes data must be provided in the kwargs.",
        "triangles": "Triangles must be provided in the kwargs.",
        "radius": "Radius must be provided in the kwargs.",
    }

    def __init__(self, data):
        """
        :param data: The graph object to wrap (``build`` passes a PyG
            ``Data`` instance).
        """
        self.data = data

    @staticmethod
    def _require(kwargs, *keys):
        """
        Check that every key in ``keys`` is present in ``kwargs``.

        :raises ValueError: For the first missing key, in the order given.
        """
        for key in keys:
            if key not in kwargs:
                raise ValueError(Graph._MISSING_ARG_MESSAGES[key])

    @staticmethod
    def _edges_from_triangles(triangles):
        """
        Return the unique edges of a triangle list as a ``(2, E)`` long tensor.

        Each triangle contributes its three sides; every edge is stored once
        as ``(min, max)`` and the columns are sorted lexicographically.

        :param triangles: Anything ``torch.as_tensor`` accepts that holds
            three vertex indices per triangle (list of triples, array, tensor).
        :return: Integer (``torch.long``) edge index of shape ``(2, E)``.
        """
        tri = torch.as_tensor(triangles, dtype=torch.long).reshape(-1, 3)
        if tri.numel() == 0:
            # torch.unique(dim=...) on an empty tensor is not reliable across
            # torch versions, so short-circuit to an empty edge index.
            return torch.empty((2, 0), dtype=torch.long)
        edges = torch.cat(
            (tri[:, [0, 1]], tri[:, [1, 2]], tri[:, [2, 0]]), dim=0
        )
        # Orient each edge as (smaller, larger) so shared sides deduplicate.
        edges = torch.sort(edges, dim=1).values
        edges = torch.unique(edges, dim=0)
        return edges.t().contiguous()

    @staticmethod
    def _build_triangulation(**kwargs):
        """
        Build a PyG ``Data`` object whose edges are the sides of triangles.

        Required kwargs: ``nodes_coordinates``, ``nodes_data``, ``triangles``.

        :raises ValueError: If a required keyword argument is missing.
        """
        logging.debug("Creating graph with triangulation mode.")

        # check for mandatory arguments
        Graph._require(kwargs, "nodes_coordinates", "nodes_data", "triangles")

        nodes_coordinates = kwargs["nodes_coordinates"]
        nodes_data = kwargs["nodes_data"]

        # NOTE(review): each edge is stored in one direction only (min -> max);
        # PyG message passing is directional, so confirm callers do not need
        # both directions (e.g. torch_geometric.utils.to_undirected).
        edge_index = Graph._edges_from_triangles(kwargs["triangles"])
        logging.debug("edge_index: %s", edge_index)

        # NOTE(review): coordinates are transposed here, unlike radius mode;
        # presumably they arrive as (dim, n_nodes) — confirm with callers.
        return Data(
            x=nodes_data,
            pos=nodes_coordinates.T,
            edge_index=edge_index,
        )

    @staticmethod
    def _build_radius(**kwargs):
        """
        Build a PyG ``Data`` object connecting nodes within a given radius.

        Required kwargs: ``nodes_coordinates``, ``nodes_data``, ``radius``.
        Optional kwargs: ``edge_data`` (stored as ``edge_attr``, default
        ``None``), ``loop`` (default ``False``) and ``batch`` (default
        ``None``), both forwarded to ``torch_geometric.nn.radius_graph``.
        ``nodes_coordinates`` and ``nodes_data`` must expose ``.tensor``.

        :raises ValueError: If a required keyword argument is missing.
        """
        logging.debug("Creating graph with radius mode.")

        # check for mandatory arguments
        Graph._require(kwargs, "nodes_coordinates", "nodes_data", "radius")

        nodes_coordinates = kwargs["nodes_coordinates"]
        nodes_data = kwargs["nodes_data"]
        radius = kwargs["radius"]

        edges_data = kwargs.get("edge_data", None)
        loop = kwargs.get("loop", False)
        batch = kwargs.get("batch", None)

        logging.debug("radius: %s, loop: %s, batch: %s", radius, loop, batch)

        edge_index = radius_graph(
            x=nodes_coordinates.tensor,
            r=radius,
            loop=loop,
            batch=batch,
        )

        logging.debug("edge_index computed")
        return Data(
            x=nodes_data.tensor,
            pos=nodes_coordinates.tensor,
            edge_index=edge_index,
            edge_attr=edges_data,
        )

    @staticmethod
    def build(mode, **kwargs):
        """
        Construct a :class:`Graph` using the requested connectivity mode.

        :param mode: ``"radius"`` or ``"triangulation"``.
        :param kwargs: Forwarded to the mode-specific builder.
        :return: A new :class:`Graph` wrapping the built ``Data`` object.
        :raises ValueError: If ``mode`` is unknown or a required keyword
            argument for that mode is missing.
        """
        if mode == "radius":
            graph = Graph._build_radius(**kwargs)
        elif mode == "triangulation":
            graph = Graph._build_triangulation(**kwargs)
        else:
            raise ValueError(f"Mode {mode} not recognized")

        return Graph(graph)

    def __repr__(self):
        return f"Graph(data={self.data})"