Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -8,14 +8,14 @@ class GNOBlock(MessagePassing):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
edges_features,
|
||||
n_layers=2,
|
||||
layers=None,
|
||||
inner_size=None,
|
||||
internal_func=None,
|
||||
external_func=None
|
||||
self,
|
||||
width,
|
||||
edges_features,
|
||||
n_layers=2,
|
||||
layers=None,
|
||||
inner_size=None,
|
||||
internal_func=None,
|
||||
external_func=None,
|
||||
):
|
||||
"""
|
||||
Initialize the Graph Integral Layer, inheriting from the MessagePassing class of PyTorch Geometric.
|
||||
@@ -28,16 +28,19 @@ class GNOBlock(MessagePassing):
|
||||
:type n_layers: int
|
||||
"""
|
||||
from pina.model import FeedForward
|
||||
super(GNOBlock, self).__init__(aggr='mean')
|
||||
|
||||
super(GNOBlock, self).__init__(aggr="mean")
|
||||
self.width = width
|
||||
if layers is None and inner_size is None:
|
||||
inner_size = width
|
||||
self.dense = FeedForward(input_dimensions=edges_features,
|
||||
output_dimensions=width ** 2,
|
||||
n_layers=n_layers,
|
||||
layers=layers,
|
||||
inner_size=inner_size,
|
||||
func=internal_func)
|
||||
self.dense = FeedForward(
|
||||
input_dimensions=edges_features,
|
||||
output_dimensions=width**2,
|
||||
n_layers=n_layers,
|
||||
layers=layers,
|
||||
inner_size=inner_size,
|
||||
func=internal_func,
|
||||
)
|
||||
self.W = torch.nn.Linear(width, width)
|
||||
self.func = external_func()
|
||||
|
||||
@@ -53,7 +56,7 @@ class GNOBlock(MessagePassing):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
x = self.dense(edge_attr).view(-1, self.width, self.width)
|
||||
return torch.einsum('bij,bj->bi', x, x_j)
|
||||
return torch.einsum("bij,bj->bi", x, x_j)
|
||||
|
||||
def update(self, aggr_out, x):
|
||||
"""
|
||||
@@ -82,6 +85,4 @@ class GNOBlock(MessagePassing):
|
||||
:return: Output of a single iteration over the Graph Integral Layer.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self.func(
|
||||
self.propagate(edge_index, x=x, edge_attr=edge_attr)
|
||||
)
|
||||
return self.func(self.propagate(edge_index, x=x, edge_attr=edge_attr))
|
||||
|
||||
Reference in New Issue
Block a user