try a new model

This commit is contained in:
Filippo Olivo
2025-11-12 15:20:43 +01:00
parent a2dd348423
commit dc59114f4a
3 changed files with 31 additions and 106 deletions

View File

@@ -122,10 +122,11 @@ class GraphDataModule(LightningDataModule):
edge_index_mask = ~torch.isin(edge_index[1], boundary_idx)
edge_index = edge_index[:, edge_index_mask]
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
edge_attr = torch.cat(
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
)
# edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
# edge_attr = torch.cat(
# [edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
# )
edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1)
x = torch.zeros_like(temperature, dtype=torch.float32).unsqueeze(-1)
if self.remove_boundary_edges: