add egno (#602)
Co-authored-by: GiovanniCanali <giovanni.canali98@yahoo.it>
@@ -5,19 +5,22 @@ from pina.model.block.message_passing import EnEquivariantNetworkBlock
 # Data for testing
 x = torch.rand(10, 4)
 pos = torch.rand(10, 3)
-edge_index = torch.randint(0, 10, (2, 20))
-edge_attr = torch.randn(20, 2)
+velocity = torch.rand(10, 3)
+edge_idx = torch.randint(0, 10, (2, 20))
+edge_attributes = torch.randn(20, 2)


 @pytest.mark.parametrize("node_feature_dim", [1, 3])
 @pytest.mark.parametrize("edge_feature_dim", [0, 2])
 @pytest.mark.parametrize("pos_dim", [2, 3])
-def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
+@pytest.mark.parametrize("use_velocity", [True, False])
+def test_constructor(node_feature_dim, edge_feature_dim, pos_dim, use_velocity):

     EnEquivariantNetworkBlock(
         node_feature_dim=node_feature_dim,
         edge_feature_dim=edge_feature_dim,
         pos_dim=pos_dim,
+        use_velocity=use_velocity,
         hidden_dim=64,
         n_message_layers=2,
         n_update_layers=2,
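A note on the parametrization above: stacked @pytest.mark.parametrize decorators combine multiplicatively, so with the new use_velocity flag the constructor test runs over the full 2 * 2 * 2 * 2 = 16 grid of argument combinations. A minimal standalone sketch of the mechanism (hypothetical test name, plain pytest):

import pytest

@pytest.mark.parametrize("a", [1, 3])
@pytest.mark.parametrize("b", [0, 2])
def test_grid(a, b):
    # pytest generates the Cartesian product: (1, 0), (1, 2), (3, 0), (3, 2)
    assert a in (1, 3) and b in (0, 2)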
@@ -29,6 +32,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
             node_feature_dim=-1,
             edge_feature_dim=edge_feature_dim,
             pos_dim=pos_dim,
+            use_velocity=use_velocity,
         )

     # Should fail if edge_feature_dim is negative
@@ -37,6 +41,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
             node_feature_dim=node_feature_dim,
             edge_feature_dim=-1,
             pos_dim=pos_dim,
+            use_velocity=use_velocity,
         )

     # Should fail if pos_dim is negative
@@ -45,6 +50,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
             node_feature_dim=node_feature_dim,
             edge_feature_dim=edge_feature_dim,
             pos_dim=-1,
+            use_velocity=use_velocity,
         )

     # Should fail if hidden_dim is negative
@@ -54,6 +60,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
             edge_feature_dim=edge_feature_dim,
             pos_dim=pos_dim,
             hidden_dim=-1,
+            use_velocity=use_velocity,
         )

     # Should fail if n_message_layers is negative
@@ -63,6 +70,7 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
             edge_feature_dim=edge_feature_dim,
             pos_dim=pos_dim,
             n_message_layers=-1,
+            use_velocity=use_velocity,
         )

     # Should fail if n_update_layers is negative
@@ -72,11 +80,22 @@ def test_constructor(node_feature_dim, edge_feature_dim, pos_dim):
             edge_feature_dim=edge_feature_dim,
             pos_dim=pos_dim,
             n_update_layers=-1,
+            use_velocity=use_velocity,
         )

+    # Should fail if use_velocity is not boolean
+    with pytest.raises(ValueError):
+        EnEquivariantNetworkBlock(
+            node_feature_dim=node_feature_dim,
+            edge_feature_dim=edge_feature_dim,
+            pos_dim=pos_dim,
+            use_velocity="False",
+        )
+

 @pytest.mark.parametrize("edge_feature_dim", [0, 2])
-def test_forward(edge_feature_dim):
+@pytest.mark.parametrize("use_velocity", [True, False])
+def test_forward(edge_feature_dim, use_velocity):

     model = EnEquivariantNetworkBlock(
         node_feature_dim=x.shape[1],
@@ -85,21 +104,26 @@ def test_forward(edge_feature_dim):
         hidden_dim=64,
         n_message_layers=2,
         n_update_layers=2,
+        use_velocity=use_velocity,
     )

-    if edge_feature_dim == 0:
-        output_ = model(edge_index=edge_index, x=x, pos=pos)
-    else:
-        output_ = model(
-            edge_index=edge_index, x=x, pos=pos, edge_attr=edge_attr
-        )
+    # Manage inputs
+    vel = velocity if use_velocity else None
+    edge_attr = edge_attributes if edge_feature_dim > 0 else None
+
+    # Checks on output shapes
+    output_ = model(
+        x=x, pos=pos, edge_index=edge_idx, edge_attr=edge_attr, vel=vel
+    )
     assert output_[0].shape == x.shape
     assert output_[1].shape == pos.shape
+    if vel is not None:
+        assert output_[2].shape == vel.shape


 @pytest.mark.parametrize("edge_feature_dim", [0, 2])
-def test_backward(edge_feature_dim):
+@pytest.mark.parametrize("use_velocity", [True, False])
+def test_backward(edge_feature_dim, use_velocity):

     model = EnEquivariantNetworkBlock(
         node_feature_dim=x.shape[1],
@@ -108,35 +132,45 @@ def test_backward(edge_feature_dim):
         hidden_dim=64,
         n_message_layers=2,
         n_update_layers=2,
+        use_velocity=use_velocity,
     )

+    # Manage inputs
+    vel = velocity.requires_grad_() if use_velocity else None
+    edge_attr = (
+        edge_attributes.requires_grad_() if edge_feature_dim > 0 else None
+    )
+
     if edge_feature_dim == 0:
         output_ = model(
-            edge_index=edge_index,
+            edge_index=edge_idx,
             x=x.requires_grad_(),
             pos=pos.requires_grad_(),
+            vel=vel,
         )
     else:
         output_ = model(
-            edge_index=edge_index,
+            edge_index=edge_idx,
             x=x.requires_grad_(),
             pos=pos.requires_grad_(),
-            edge_attr=edge_attr.requires_grad_(),
+            edge_attr=edge_attr,
+            vel=vel,
         )

-    loss = torch.mean(output_[0])
+    # Checks on gradients
+    loss = sum(torch.mean(output_[i]) for i in range(len(output_)))
     loss.backward()
     assert x.grad.shape == x.shape
     assert pos.grad.shape == pos.shape
+    if use_velocity:
+        assert vel.grad.shape == vel.shape


-def test_equivariance():
+@pytest.mark.parametrize("edge_feature_dim", [0, 2])
+@pytest.mark.parametrize("use_velocity", [True, False])
+def test_equivariance(edge_feature_dim, use_velocity):

     # Graph to be fully connected and undirected
     edge_index = torch.combinations(torch.arange(x.shape[0]), r=2).T
     edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

-    # Random rotation (det(rotation) should be 1)
+    # Random rotation
     rotation = torch.linalg.qr(torch.rand(pos.shape[-1], pos.shape[-1])).Q
     if torch.det(rotation) < 0:
         rotation[:, 0] *= -1
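The fully connected graph built at the top of test_equivariance relies on torch.combinations, which enumerates each unordered node pair exactly once; concatenating the flipped copy then adds the reverse edges. A quick standalone sanity check of the resulting shape (a sketch, independent of PINA):

import torch

n = 10
edge_index = torch.combinations(torch.arange(n), r=2).T  # one direction per pair
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)  # add reverse edges
assert edge_index.shape == (2, n * (n - 1))  # 90 directed edges for n = 10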
@@ -146,20 +180,37 @@ def test_equivariance():

     model = EnEquivariantNetworkBlock(
         node_feature_dim=x.shape[1],
-        edge_feature_dim=0,
+        edge_feature_dim=edge_feature_dim,
         pos_dim=pos.shape[1],
         hidden_dim=64,
         n_message_layers=2,
         n_update_layers=2,
+        use_velocity=use_velocity,
     ).eval()

-    h1, pos1 = model(edge_index=edge_index, x=x, pos=pos)
-    h2, pos2 = model(
-        edge_index=edge_index, x=x, pos=pos @ rotation.T + translation
-    )
+    # Manage inputs
+    vel = velocity if use_velocity else None
+    edge_attr = edge_attributes if edge_feature_dim > 0 else None
+
+    # Transform inputs (no translation for velocity)
+    pos_rot = pos @ rotation.T + translation
+    vel_rot = vel @ rotation.T if use_velocity else vel
+
+    # Get model outputs
+    out1 = model(
+        x=x, pos=pos, edge_index=edge_idx, edge_attr=edge_attr, vel=vel
+    )
+    out2 = model(
+        x=x, pos=pos_rot, edge_index=edge_idx, edge_attr=edge_attr, vel=vel_rot
+    )

-    # Transform model output
-    pos1_transformed = (pos1 @ rotation.T) + translation
+    # Unpack outputs
+    h1, pos1, *other1 = out1
+    h2, pos2, *other2 = out2
+    if use_velocity:
+        vel1, vel2 = other1[0], other2[0]

-    assert torch.allclose(pos2, pos1_transformed, atol=1e-5)
+    assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)
     assert torch.allclose(h1, h2, atol=1e-5)
+    if vel is not None:
+        assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5)
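The rotation setup shared by the equivariance tests deserves a word: the QR decomposition of a random matrix yields an orthogonal Q, but its determinant may be -1, i.e. a reflection rather than a rotation. Flipping the sign of one column restores determinant +1. A minimal check of this, in plain PyTorch:

import torch

dim = 3
rotation = torch.linalg.qr(torch.rand(dim, dim)).Q
if torch.det(rotation) < 0:
    rotation[:, 0] *= -1  # turn a reflection into a proper rotation

# Sanity checks: orthogonality and positive (unit) determinant
assert torch.allclose(rotation @ rotation.T, torch.eye(dim), atol=1e-6)
assert torch.det(rotation) > 0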
132 tests/test_messagepassing/test_equivariant_operator_block.py Normal file
@@ -0,0 +1,132 @@
import pytest
import torch
from pina.model.block.message_passing import EquivariantGraphNeuralOperatorBlock

# Data for testing. Shapes: (time, nodes, features)
x = torch.rand(5, 10, 4)
pos = torch.rand(5, 10, 3)
vel = torch.rand(5, 10, 3)

# Edge index and attributes
edge_idx = torch.randint(0, 10, (2, 20))
edge_attributes = torch.randn(20, 2)


@pytest.mark.parametrize("node_feature_dim", [1, 3])
@pytest.mark.parametrize("edge_feature_dim", [0, 2])
@pytest.mark.parametrize("pos_dim", [2, 3])
@pytest.mark.parametrize("modes", [1, 5])
def test_constructor(node_feature_dim, edge_feature_dim, pos_dim, modes):

    EquivariantGraphNeuralOperatorBlock(
        node_feature_dim=node_feature_dim,
        edge_feature_dim=edge_feature_dim,
        pos_dim=pos_dim,
        modes=modes,
    )

    # Should fail if modes is negative
    with pytest.raises(AssertionError):
        EquivariantGraphNeuralOperatorBlock(
            node_feature_dim=node_feature_dim,
            edge_feature_dim=edge_feature_dim,
            pos_dim=pos_dim,
            modes=-1,
        )


@pytest.mark.parametrize("modes", [1, 5])
def test_forward(modes):

    model = EquivariantGraphNeuralOperatorBlock(
        node_feature_dim=x.shape[2],
        edge_feature_dim=edge_attributes.shape[1],
        pos_dim=pos.shape[2],
        modes=modes,
    )

    output_ = model(
        x=x,
        pos=pos,
        vel=vel,
        edge_index=edge_idx,
        edge_attr=edge_attributes,
    )

    # Checks on output shapes
    assert output_[0].shape == x.shape
    assert output_[1].shape == pos.shape
    assert output_[2].shape == vel.shape


@pytest.mark.parametrize("modes", [1, 5])
def test_backward(modes):

    model = EquivariantGraphNeuralOperatorBlock(
        node_feature_dim=x.shape[2],
        edge_feature_dim=edge_attributes.shape[1],
        pos_dim=pos.shape[2],
        modes=modes,
    )

    output_ = model(
        x=x.requires_grad_(),
        pos=pos.requires_grad_(),
        vel=vel.requires_grad_(),
        edge_index=edge_idx,
        edge_attr=edge_attributes.requires_grad_(),
    )

    # Checks on gradients
    loss = sum(torch.mean(output_[i]) for i in range(len(output_)))
    loss.backward()
    assert x.grad.shape == x.shape
    assert pos.grad.shape == pos.shape
    assert vel.grad.shape == vel.shape


@pytest.mark.parametrize("modes", [1, 5])
def test_equivariance(modes):

    # Random rotation
    rotation = torch.linalg.qr(torch.rand(pos.shape[2], pos.shape[2])).Q
    if torch.det(rotation) < 0:
        rotation[:, 0] *= -1

    # Random translation
    translation = torch.rand(1, pos.shape[2])

    model = EquivariantGraphNeuralOperatorBlock(
        node_feature_dim=x.shape[2],
        edge_feature_dim=edge_attributes.shape[1],
        pos_dim=pos.shape[2],
        modes=modes,
    ).eval()

    # Transform inputs (no translation for velocity)
    pos_rot = pos @ rotation.T + translation
    vel_rot = vel @ rotation.T

    # Get model outputs
    out1 = model(
        x=x,
        pos=pos,
        vel=vel,
        edge_index=edge_idx,
        edge_attr=edge_attributes,
    )
    out2 = model(
        x=x,
        pos=pos_rot,
        vel=vel_rot,
        edge_index=edge_idx,
        edge_attr=edge_attributes,
    )

    # Unpack outputs
    h1, pos1, vel1 = out1
    h2, pos2, vel2 = out2

    assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)
    assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5)
    assert torch.allclose(h1, h2, atol=1e-5)
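The gradient checks in these tests reduce every output tensor to a scalar and sum the means before calling backward(); this guarantees gradients flow from all outputs to all inputs, including inputs that influence only one of the outputs. A toy sketch of the pattern (plain PyTorch, not PINA code):

import torch

a = torch.rand(4, 3).requires_grad_()
b = torch.rand(4, 3).requires_grad_()
outputs = (a * 2.0, b + 1.0)  # toy "model": each output depends on one input

# Summing the means couples every output into one scalar loss
loss = sum(torch.mean(o) for o in outputs)
loss.backward()
assert a.grad.shape == a.shape
assert b.grad.shape == b.shape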
194 tests/test_model/test_equivariant_graph_neural_operator.py Normal file
@@ -0,0 +1,194 @@
import pytest
import torch
import copy
from pina.model import EquivariantGraphNeuralOperator
from pina.graph import Graph


# Utility to create graphs
def make_graph(include_vel=True, use_edge_attr=True):
    data = dict(
        x=torch.rand(10, 4),
        pos=torch.rand(10, 3),
        edge_index=torch.randint(0, 10, (2, 20)),
        edge_attr=torch.randn(20, 2) if use_edge_attr else None,
    )
    if include_vel:
        data["vel"] = torch.rand(10, 3)
    return Graph(**data)


@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 3])
@pytest.mark.parametrize("time_emb_dim", [4, 8])
@pytest.mark.parametrize("max_time_idx", [10, 20])
def test_constructor(n_egno_layers, time_steps, time_emb_dim, max_time_idx):

    # Create graph and model
    graph = make_graph()
    EquivariantGraphNeuralOperator(
        n_egno_layers=n_egno_layers,
        node_feature_dim=graph.x.shape[1],
        edge_feature_dim=graph.edge_attr.shape[1],
        pos_dim=graph.pos.shape[1],
        modes=5,
        time_steps=time_steps,
        time_emb_dim=time_emb_dim,
        max_time_idx=max_time_idx,
    )

    # Should fail if n_egno_layers is negative
    with pytest.raises(AssertionError):
        EquivariantGraphNeuralOperator(
            n_egno_layers=-1,
            node_feature_dim=graph.x.shape[1],
            edge_feature_dim=graph.edge_attr.shape[1],
            pos_dim=graph.pos.shape[1],
            modes=5,
            time_steps=time_steps,
            time_emb_dim=time_emb_dim,
            max_time_idx=max_time_idx,
        )

    # Should fail if time_steps is negative
    with pytest.raises(AssertionError):
        EquivariantGraphNeuralOperator(
            n_egno_layers=n_egno_layers,
            node_feature_dim=graph.x.shape[1],
            edge_feature_dim=graph.edge_attr.shape[1],
            pos_dim=graph.pos.shape[1],
            modes=5,
            time_steps=-1,
            time_emb_dim=time_emb_dim,
            max_time_idx=max_time_idx,
        )

    # Should fail if max_time_idx is negative
    with pytest.raises(AssertionError):
        EquivariantGraphNeuralOperator(
            n_egno_layers=n_egno_layers,
            node_feature_dim=graph.x.shape[1],
            edge_feature_dim=graph.edge_attr.shape[1],
            pos_dim=graph.pos.shape[1],
            modes=5,
            time_steps=time_steps,
            time_emb_dim=time_emb_dim,
            max_time_idx=-1,
        )

    # Should fail if time_emb_dim is negative
    with pytest.raises(AssertionError):
        EquivariantGraphNeuralOperator(
            n_egno_layers=n_egno_layers,
            node_feature_dim=graph.x.shape[1],
            edge_feature_dim=graph.edge_attr.shape[1],
            pos_dim=graph.pos.shape[1],
            modes=5,
            time_steps=time_steps,
            time_emb_dim=-1,
            max_time_idx=max_time_idx,
        )


@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 5])
@pytest.mark.parametrize("modes", [1, 3, 10])
@pytest.mark.parametrize("use_edge_attr", [True, False])
def test_forward(n_egno_layers, time_steps, modes, use_edge_attr):

    # Create graph and model
    graph = make_graph(use_edge_attr=use_edge_attr)
    model = EquivariantGraphNeuralOperator(
        n_egno_layers=n_egno_layers,
        node_feature_dim=graph.x.shape[1],
        edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0,
        pos_dim=graph.pos.shape[1],
        modes=modes,
        time_steps=time_steps,
    )

    # Checks on output shapes
    output_ = model(graph)
    assert output_.x.shape == (time_steps, *graph.x.shape)
    assert output_.pos.shape == (time_steps, *graph.pos.shape)
    assert output_.vel.shape == (time_steps, *graph.vel.shape)

    # Should fail if the graph has no vel attribute
    with pytest.raises(ValueError):
        graph_no_vel = make_graph(include_vel=False)
        model(graph_no_vel)


@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 5])
@pytest.mark.parametrize("modes", [1, 3, 10])
@pytest.mark.parametrize("use_edge_attr", [True, False])
def test_backward(n_egno_layers, time_steps, modes, use_edge_attr):

    # Create graph and model
    graph = make_graph(use_edge_attr=use_edge_attr)
    model = EquivariantGraphNeuralOperator(
        n_egno_layers=n_egno_layers,
        node_feature_dim=graph.x.shape[1],
        edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0,
        pos_dim=graph.pos.shape[1],
        modes=modes,
        time_steps=time_steps,
    )

    # Set requires_grad and perform forward pass
    graph.x.requires_grad_()
    graph.pos.requires_grad_()
    graph.vel.requires_grad_()
    out = model(graph)

    # Checks on gradients
    loss = torch.mean(out.x) + torch.mean(out.pos) + torch.mean(out.vel)
    loss.backward()
    assert graph.x.grad.shape == graph.x.shape
    assert graph.pos.grad.shape == graph.pos.shape
    assert graph.vel.grad.shape == graph.vel.shape


@pytest.mark.parametrize("n_egno_layers", [1, 3])
@pytest.mark.parametrize("time_steps", [1, 5])
@pytest.mark.parametrize("modes", [1, 3, 10])
@pytest.mark.parametrize("use_edge_attr", [True, False])
def test_equivariance(n_egno_layers, time_steps, modes, use_edge_attr):

    graph = make_graph(use_edge_attr=use_edge_attr)
    model = EquivariantGraphNeuralOperator(
        n_egno_layers=n_egno_layers,
        node_feature_dim=graph.x.shape[1],
        edge_feature_dim=graph.edge_attr.shape[1] if use_edge_attr else 0,
        pos_dim=graph.pos.shape[1],
        modes=modes,
        time_steps=time_steps,
    ).eval()

    # Random rotation
    rotation = torch.linalg.qr(
        torch.rand(graph.pos.shape[1], graph.pos.shape[1])
    ).Q
    if torch.det(rotation) < 0:
        rotation[:, 0] *= -1

    # Random translation
    translation = torch.rand(1, graph.pos.shape[1])

    # Transform graph (no translation for velocity)
    graph_rot = copy.deepcopy(graph)
    graph_rot.pos = graph.pos @ rotation.T + translation
    graph_rot.vel = graph.vel @ rotation.T

    # Get model outputs
    out1 = model(graph)
    out2 = model(graph_rot)

    # Unpack outputs
    h1, pos1, vel1 = out1.x, out1.pos, out1.vel
    h2, pos2, vel2 = out2.x, out2.pos, out2.vel

    assert torch.allclose(pos2, pos1 @ rotation.T + translation, atol=1e-5)
    assert torch.allclose(vel2, vel1 @ rotation.T, atol=1e-5)
    assert torch.allclose(h1, h2, atol=1e-5)
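All three equivariance tests in this commit assert the same contract: applying a rotation R and translation t to the input positions must equal applying them to the output positions, velocities rotate without translating, and node features stay invariant. A toy illustration of the property on a map that is E(n)-equivariant by construction (f below is a hypothetical stand-in, not the PINA model):

import torch

def f(pos):
    # Doubling positions about their centroid is E(n)-equivariant:
    # f(pos @ R.T + t) == f(pos) @ R.T + t
    return 2.0 * pos - pos.mean(dim=0, keepdim=True)

pos = torch.rand(10, 3)
rotation = torch.linalg.qr(torch.rand(3, 3)).Q
if torch.det(rotation) < 0:
    rotation[:, 0] *= -1
translation = torch.rand(1, 3)

out1 = f(pos)
out2 = f(pos @ rotation.T + translation)
assert torch.allclose(out2, out1 @ rotation.T + translation, atol=1e-5)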