# Tests for EnEquivariantNetworkBlock (pina.model.block.message_passing).
import pytest
|
|
import torch
|
|
from pina.model.block.message_passing import EnEquivariantNetworkBlock
|
|
|
|
# Data for testing
|
|
# Node features: 10 nodes, 4 features each.
x = torch.rand(10, 4)

# Node positions: 10 nodes in 3-D space.
pos = torch.rand(10, 3)

# Per-node velocity vectors (same dimensionality as positions).
velocity = torch.rand(10, 3)

# 20 random directed edges among the 10 nodes, as (source, target) index rows.
edge_idx = torch.randint(0, 10, (2, 20))

# Two scalar attributes per edge.
edge_attributes = torch.randn(20, 2)
|
@pytest.mark.parametrize("node_feature_dim", [1, 3])
@pytest.mark.parametrize("edge_feature_dim", [0, 2])
@pytest.mark.parametrize("pos_dim", [2, 3])
@pytest.mark.parametrize("use_velocity", [True, False])
def test_constructor(node_feature_dim, edge_feature_dim, pos_dim, use_velocity):
    """Valid hyperparameters construct the block; invalid ones raise."""
    base_kwargs = dict(
        node_feature_dim=node_feature_dim,
        edge_feature_dim=edge_feature_dim,
        pos_dim=pos_dim,
        use_velocity=use_velocity,
    )

    # A fully-specified valid configuration must construct without error.
    EnEquivariantNetworkBlock(
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
        **base_kwargs,
    )

    # Each invalid override must trigger the documented exception:
    # negative dimensions/layer counts fail assertions, and a non-boolean
    # use_velocity raises a ValueError.
    invalid_cases = [
        ({"node_feature_dim": -1}, AssertionError),
        ({"edge_feature_dim": -1}, AssertionError),
        ({"pos_dim": -1}, AssertionError),
        ({"hidden_dim": -1}, AssertionError),
        ({"n_message_layers": -1}, AssertionError),
        ({"n_update_layers": -1}, AssertionError),
        ({"use_velocity": "False"}, ValueError),
    ]
    for override, expected_exc in invalid_cases:
        with pytest.raises(expected_exc):
            EnEquivariantNetworkBlock(**{**base_kwargs, **override})
|
@pytest.mark.parametrize("edge_feature_dim", [0, 2])
@pytest.mark.parametrize("use_velocity", [True, False])
def test_forward(edge_feature_dim, use_velocity):
    """Forward pass returns outputs whose shapes mirror the inputs."""
    block = EnEquivariantNetworkBlock(
        node_feature_dim=x.shape[1],
        edge_feature_dim=edge_feature_dim,
        pos_dim=pos.shape[1],
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
        use_velocity=use_velocity,
    )

    # Optional inputs are only supplied when the configuration uses them.
    vel_input = velocity if use_velocity else None
    attr_input = None if edge_feature_dim == 0 else edge_attributes

    result = block(
        x=x, pos=pos, edge_index=edge_idx, edge_attr=attr_input, vel=vel_input
    )

    # Node features, positions, and (when present) velocities keep their shapes.
    assert result[0].shape == x.shape
    assert result[1].shape == pos.shape
    if vel_input is not None:
        assert result[2].shape == vel_input.shape
|
@pytest.mark.parametrize("edge_feature_dim", [0, 2])
@pytest.mark.parametrize("use_velocity", [True, False])
def test_backward(edge_feature_dim, use_velocity):
    """Gradients flow back to every differentiable input.

    Uses detached clones as autograd leaves so the shared module-level
    fixtures stay untouched: the previous in-place ``requires_grad_()``
    calls permanently flipped ``requires_grad=True`` on the shared tensors
    and accumulated ``.grad`` across the four parametrized runs (and leaked
    into the other tests in this module).
    """
    model = EnEquivariantNetworkBlock(
        node_feature_dim=x.shape[1],
        edge_feature_dim=edge_feature_dim,
        pos_dim=pos.shape[1],
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
        use_velocity=use_velocity,
    )

    # Local leaf tensors; the shared fixtures are never mutated.
    x_in = x.clone().detach().requires_grad_()
    pos_in = pos.clone().detach().requires_grad_()
    vel = velocity.clone().detach().requires_grad_() if use_velocity else None
    edge_attr = (
        edge_attributes.clone().detach().requires_grad_()
        if edge_feature_dim > 0
        else None
    )

    # Passing edge_attr=None is equivalent to omitting it (test_forward
    # already exercises the model this way), so one call covers both cases.
    output_ = model(
        x=x_in, pos=pos_in, edge_index=edge_idx, edge_attr=edge_attr, vel=vel
    )

    # Backpropagate a scalar built from every output head.
    loss = sum(torch.mean(out) for out in output_)
    loss.backward()

    # Each leaf input must receive a gradient of matching shape.
    assert x_in.grad.shape == x_in.shape
    assert pos_in.grad.shape == pos_in.shape
    if use_velocity:
        assert vel.grad.shape == vel.shape
|
@pytest.mark.parametrize("edge_feature_dim", [0, 2])
@pytest.mark.parametrize("use_velocity", [True, False])
def test_equivariance(edge_feature_dim, use_velocity):
    """E(n) equivariance: transformed inputs yield transformed outputs.

    Node features must be invariant; positions must co-rotate and
    co-translate; velocities must co-rotate (translations do not act
    on velocities).
    """
    dim = pos.shape[-1]

    # Draw a random proper rotation: orthogonal Q from a QR factorization,
    # with the first column flipped if needed to force det = +1.
    rotation = torch.linalg.qr(torch.rand(dim, dim)).Q
    if torch.det(rotation) < 0:
        rotation[:, 0] *= -1

    # Random translation vector, broadcast over all nodes.
    translation = torch.rand(1, dim)

    model = EnEquivariantNetworkBlock(
        node_feature_dim=x.shape[1],
        edge_feature_dim=edge_feature_dim,
        pos_dim=dim,
        hidden_dim=64,
        n_message_layers=2,
        n_update_layers=2,
        use_velocity=use_velocity,
    ).eval()

    # Optional inputs depend on the parametrization.
    vel = velocity if use_velocity else None
    edge_attr = edge_attributes if edge_feature_dim > 0 else None

    # Transformed copies of the geometric inputs.
    pos_rot = pos @ rotation.T + translation
    vel_rot = vel @ rotation.T if use_velocity else vel

    # Evaluate on the original and on the transformed inputs.
    out1 = model(
        x=x, pos=pos, edge_index=edge_idx, edge_attr=edge_attr, vel=vel
    )
    out2 = model(
        x=x, pos=pos_rot, edge_index=edge_idx, edge_attr=edge_attr, vel=vel_rot
    )

    h1, pos1, *rest1 = out1
    h2, pos2, *rest2 = out2

    # Positions equivariant, node features invariant.
    assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)
    assert torch.allclose(h1, h2, atol=1e-5)
    if use_velocity:
        assert torch.allclose(rest2[0], rest1[0] @ rotation.T, atol=1e-5)