Compare commits
1 Commits
refact-dat
...
fix-codacy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40747c56ff |
@@ -41,7 +41,7 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
|||||||
DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
|
DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__( # pylint: disable=R0913, R0917
|
||||||
self,
|
self,
|
||||||
node_feature_dim,
|
node_feature_dim,
|
||||||
edge_feature_dim,
|
edge_feature_dim,
|
||||||
@@ -143,7 +143,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
|||||||
func=activation,
|
func=activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, pos, edge_index, edge_attr=None, vel=None):
|
def forward(
|
||||||
|
self, x, pos, edge_index, edge_attr=None, vel=None
|
||||||
|
): # pylint: disable=R0917
|
||||||
"""
|
"""
|
||||||
Forward pass of the block, triggering the message-passing routine.
|
Forward pass of the block, triggering the message-passing routine.
|
||||||
|
|
||||||
@@ -169,7 +171,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
|||||||
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel
|
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel
|
||||||
)
|
)
|
||||||
|
|
||||||
def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
|
def message(
|
||||||
|
self, x_i, x_j, pos_i, pos_j, edge_attr
|
||||||
|
): # pylint: disable=R0917
|
||||||
"""
|
"""
|
||||||
Compute the message to be passed between nodes and edges.
|
Compute the message to be passed between nodes and edges.
|
||||||
|
|
||||||
@@ -234,7 +238,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
|||||||
|
|
||||||
return agg_message, agg_m_ij
|
return agg_message, agg_m_ij
|
||||||
|
|
||||||
def update(self, aggregated_inputs, x, pos, edge_index, vel):
|
def update(
|
||||||
|
self, aggregated_inputs, x, pos, edge_index, vel
|
||||||
|
): # pylint: disable=R0917
|
||||||
"""
|
"""
|
||||||
Update node features, positions, and optionally velocities.
|
Update node features, positions, and optionally velocities.
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
|
|||||||
<https://arxiv.org/abs/2401.11037>`_
|
<https://arxiv.org/abs/2401.11037>`_
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__( # pylint: disable=R0913, R0917
|
||||||
self,
|
self,
|
||||||
node_feature_dim,
|
node_feature_dim,
|
||||||
edge_feature_dim,
|
edge_feature_dim,
|
||||||
@@ -101,7 +101,9 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
|
|||||||
flow=flow,
|
flow=flow,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, pos, vel, edge_index, edge_attr=None):
|
def forward( # pylint: disable=R0917
|
||||||
|
self, x, pos, vel, edge_index, edge_attr=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Forward pass of the Equivariant Graph Neural Operator block.
|
Forward pass of the Equivariant Graph Neural Operator block.
|
||||||
|
|
||||||
@@ -182,7 +184,11 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
|
|||||||
weights = torch.complex(real[..., :modes], img[..., :modes])
|
weights = torch.complex(real[..., :modes], img[..., :modes])
|
||||||
|
|
||||||
# Convolution in Fourier space
|
# Convolution in Fourier space
|
||||||
fourier = torch.fft.rfftn(x, dim=[0])[:modes]
|
# torch.fft.rfftn and irfftn are callable functions, but pylint
|
||||||
|
# incorrectly flags them as E1102 (not callable).
|
||||||
|
fourier = torch.fft.rfftn(x, dim=[0])[:modes] # pylint: disable=E1102
|
||||||
out = torch.einsum(einsum_idx, fourier, weights)
|
out = torch.einsum(einsum_idx, fourier, weights)
|
||||||
|
|
||||||
return torch.fft.irfftn(out, s=x.shape[0], dim=0)
|
return torch.fft.irfftn( # pylint: disable=E1102
|
||||||
|
out, s=x.shape[0], dim=0
|
||||||
|
)
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ from ..utils import check_positive_integer
|
|||||||
from .block.message_passing import EquivariantGraphNeuralOperatorBlock
|
from .block.message_passing import EquivariantGraphNeuralOperatorBlock
|
||||||
|
|
||||||
|
|
||||||
class EquivariantGraphNeuralOperator(torch.nn.Module):
|
# Disable pylint warnings for too few public methods (since this is a simple
|
||||||
|
# model class in a standard PyTorch style)
|
||||||
|
class EquivariantGraphNeuralOperator(torch.nn.Module): # pylint: disable=R0903
|
||||||
"""
|
"""
|
||||||
Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.
|
Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.
|
||||||
|
|
||||||
@@ -32,7 +34,9 @@ class EquivariantGraphNeuralOperator(torch.nn.Module):
|
|||||||
<https://arxiv.org/abs/2401.11037>`_
|
<https://arxiv.org/abs/2401.11037>`_
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
# Disable pylint warnings for too many arguments in init (since this is a
|
||||||
|
# model class with many configurable parameters)
|
||||||
|
def __init__( # pylint: disable=R0913, R0917, R0914
|
||||||
self,
|
self,
|
||||||
n_egno_layers,
|
n_egno_layers,
|
||||||
node_feature_dim,
|
node_feature_dim,
|
||||||
|
|||||||
Reference in New Issue
Block a user