Compare commits
1 Commits
refact-dat
...
fix-codacy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40747c56ff |
@@ -41,7 +41,7 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
||||
DOI: `<https://doi.org/10.48550/arXiv.2102.09844>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def __init__( # pylint: disable=R0913, R0917
|
||||
self,
|
||||
node_feature_dim,
|
||||
edge_feature_dim,
|
||||
@@ -143,7 +143,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
||||
func=activation,
|
||||
)
|
||||
|
||||
def forward(self, x, pos, edge_index, edge_attr=None, vel=None):
|
||||
def forward(
|
||||
self, x, pos, edge_index, edge_attr=None, vel=None
|
||||
): # pylint: disable=R0917
|
||||
"""
|
||||
Forward pass of the block, triggering the message-passing routine.
|
||||
|
||||
@@ -169,7 +171,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
||||
edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr, vel=vel
|
||||
)
|
||||
|
||||
def message(self, x_i, x_j, pos_i, pos_j, edge_attr):
|
||||
def message(
|
||||
self, x_i, x_j, pos_i, pos_j, edge_attr
|
||||
): # pylint: disable=R0917
|
||||
"""
|
||||
Compute the message to be passed between nodes and edges.
|
||||
|
||||
@@ -234,7 +238,9 @@ class EnEquivariantNetworkBlock(MessagePassing):
|
||||
|
||||
return agg_message, agg_m_ij
|
||||
|
||||
def update(self, aggregated_inputs, x, pos, edge_index, vel):
|
||||
def update(
|
||||
self, aggregated_inputs, x, pos, edge_index, vel
|
||||
): # pylint: disable=R0917
|
||||
"""
|
||||
Update node features, positions, and optionally velocities.
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
|
||||
<https://arxiv.org/abs/2401.11037>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def __init__( # pylint: disable=R0913, R0917
|
||||
self,
|
||||
node_feature_dim,
|
||||
edge_feature_dim,
|
||||
@@ -101,7 +101,9 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
|
||||
flow=flow,
|
||||
)
|
||||
|
||||
def forward(self, x, pos, vel, edge_index, edge_attr=None):
|
||||
def forward( # pylint: disable=R0917
|
||||
self, x, pos, vel, edge_index, edge_attr=None
|
||||
):
|
||||
"""
|
||||
Forward pass of the Equivariant Graph Neural Operator block.
|
||||
|
||||
@@ -182,7 +184,11 @@ class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
|
||||
weights = torch.complex(real[..., :modes], img[..., :modes])
|
||||
|
||||
# Convolution in Fourier space
|
||||
fourier = torch.fft.rfftn(x, dim=[0])[:modes]
|
||||
# torch.fft.rfftn and irfftn are callable functions, but pylint
|
||||
# incorrectly flags them as E1102 (not callable).
|
||||
fourier = torch.fft.rfftn(x, dim=[0])[:modes] # pylint: disable=E1102
|
||||
out = torch.einsum(einsum_idx, fourier, weights)
|
||||
|
||||
return torch.fft.irfftn(out, s=x.shape[0], dim=0)
|
||||
return torch.fft.irfftn( # pylint: disable=E1102
|
||||
out, s=x.shape[0], dim=0
|
||||
)
|
||||
|
||||
@@ -5,7 +5,9 @@ from ..utils import check_positive_integer
|
||||
from .block.message_passing import EquivariantGraphNeuralOperatorBlock
|
||||
|
||||
|
||||
class EquivariantGraphNeuralOperator(torch.nn.Module):
|
||||
# Disable pylint warnings for too few public methods (since this is a simple
|
||||
# model class in a standard PyTorch style)
|
||||
class EquivariantGraphNeuralOperator(torch.nn.Module): # pylint: disable=R0903
|
||||
"""
|
||||
Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.
|
||||
|
||||
@@ -32,7 +34,9 @@ class EquivariantGraphNeuralOperator(torch.nn.Module):
|
||||
<https://arxiv.org/abs/2401.11037>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
# Disable pylint warnings for too many arguments in init (since this is a
|
||||
# model class with many configurable parameters)
|
||||
def __init__( # pylint: disable=R0913, R0917, R0914
|
||||
self,
|
||||
n_egno_layers,
|
||||
node_feature_dim,
|
||||
|
||||
Reference in New Issue
Block a user