add module and first model

This commit is contained in:
Filippo Olivo
2025-09-24 15:16:41 +02:00
parent bb9241d9a0
commit d53b076ecc
3 changed files with 200 additions and 11 deletions

View File

@@ -4,6 +4,7 @@ from lightning import LightningDataModule
from datasets import load_dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected
class GraphDataModule(LightningDataModule):
@@ -34,7 +35,7 @@ class GraphDataModule(LightningDataModule):
self.split_name
]
edge_index = torch.tensor(
self.geometry["edge_index"][0], dtype=torch.int32
self.geometry["edge_index"][0], dtype=torch.int64
)
pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
:, :2
@@ -51,23 +52,29 @@ class GraphDataModule(LightningDataModule):
]
def _build_dataset(
self, conductivity, boundary_vales, temperature, edge_index, pos
):
input_ = torch.stack([conductivity, boundary_vales], dim=1)
self,
conductivity: torch.Tensor,
boundary_vales: torch.Tensor,
temperature: torch.Tensor,
edge_index: torch.Tensor,
pos: torch.Tensor,
) -> Data:
edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
edge_attr = pos[edge_index[0]] - pos[edge_index[1]]
edge_attr = torch.cat(
[edge_attr, torch.norm(edge_attr, dim=1).unsqueeze(-1)], dim=1
)
return Data(
x=input_,
x=boundary_vales.unsqueeze(-1),
c=conductivity.unsqueeze(-1),
edge_index=edge_index,
pos=pos,
edge_attr=edge_attr,
y=temperature,
y=temperature.unsqueeze(-1),
)
def setup(self, stage=None):
def setup(self, stage: str = None):
n = len(self.data)
train_end = int(n * self.train_size)
val_end = train_end + int(n * self.val_size)
@@ -78,13 +85,13 @@ class GraphDataModule(LightningDataModule):
if stage == "test" or stage is None:
self.test_data = self.data[val_end:]
def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
return DataLoader(
self.train_data, batch_size=self.batch_size, shuffle=True
)
def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
return DataLoader(self.val_data, batch_size=self.batch_size)
def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
return DataLoader(self.test_data, batch_size=self.batch_size)