diff --git a/ThermalSolver/data_module.py b/ThermalSolver/data_module.py
new file mode 100644
index 0000000..3b65f3a
--- /dev/null
+++ b/ThermalSolver/data_module.py
@@ -0,0 +1,162 @@
+"""Lightning data module serving thermal snapshots as PyG graphs."""
+
+import torch
+from tqdm import tqdm
+from lightning import LightningDataModule
+from datasets import load_dataset
+from torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+
+
+class GraphDataModule(LightningDataModule):
+    """Turn every snapshot of a Hugging Face repository into a graph.
+
+    Two configurations of ``hf_repo`` are read: ``"snapshots"`` (fields
+    ``conductivity``, ``boundary_values``, ``temperature``) and
+    ``"geometry"`` (fields ``edge_index``, ``points``). The first geometry
+    record is attached to every snapshot graph.
+
+    Args:
+        hf_repo: Repository id handed to ``datasets.load_dataset``.
+        split_name: Split selected in both configurations.
+        train_size: Fraction of the graphs used for training.
+        val_size: Fraction of the graphs used for validation.
+        test_size: Stored only; the test split is whatever remains after
+            the training and validation slices.
+        batch_size: Batch size shared by all data loaders.
+    """
+
+    def __init__(
+        self,
+        hf_repo: str,
+        split_name: str,
+        train_size: float = 0.8,
+        val_size: float = 0.1,
+        test_size: float = 0.1,
+        batch_size: int = 32,
+    ):
+        super().__init__()
+        self.hf_repo = hf_repo
+        self.split_name = split_name
+        self.dataset = None
+        self.geometry = None
+        # Filled by _load_graphs(); None means "not built in this process".
+        self.data = None
+        self.train_size = train_size
+        self.val_size = val_size
+        self.test_size = test_size
+        self.batch_size = batch_size
+
+    def prepare_data(self):
+        """Download the dataset and build the graphs in this process.
+
+        Lightning runs this hook on a single process only, so ``setup``
+        rebuilds the graphs wherever this hook did not run.
+        """
+        self._load_graphs()
+
+    def _load_graphs(self):
+        """Load both configurations and fill ``self.data`` with graphs."""
+        hf_dataset = load_dataset(self.hf_repo, name="snapshots")[
+            self.split_name
+        ]
+        self.geometry = load_dataset(self.hf_repo, name="geometry")[
+            self.split_name
+        ]
+        # Only the first geometry record is used.
+        mesh = self.geometry[0]
+        # PyG wants a contiguous int64 edge_index of shape (2, E); assumes
+        # the stored connectivity is (E, 2) -- TODO confirm.
+        edge_index = (
+            torch.tensor(mesh["edge_index"], dtype=torch.long)
+            .t()
+            .contiguous()
+        )
+        pos = torch.tensor(mesh["points"], dtype=torch.float32)[:, :2]
+        # Edge features depend on the mesh only, so compute them once.
+        edge_attr = self._edge_features(pos, edge_index)
+        self.data = [
+            self._build_dataset(
+                torch.tensor(snapshot["conductivity"], dtype=torch.float32),
+                torch.tensor(snapshot["boundary_values"], dtype=torch.float32),
+                torch.tensor(snapshot["temperature"], dtype=torch.float32),
+                edge_index,
+                pos,
+                edge_attr,
+            )
+            for snapshot in tqdm(hf_dataset, desc="Building graphs")
+        ]
+
+    @staticmethod
+    def _edge_features(pos, edge_index):
+        """Return per-edge coordinate offsets followed by their length."""
+        offset = pos[edge_index[0]] - pos[edge_index[1]]
+        length = torch.norm(offset, dim=1, keepdim=True)
+        return torch.cat([offset, length], dim=1)
+
+    def _build_dataset(
+        self,
+        conductivity,
+        boundary_values,
+        temperature,
+        edge_index,
+        pos,
+        edge_attr=None,
+    ):
+        """Assemble the ``Data`` object of one snapshot.
+
+        Args:
+            conductivity: First node-feature column.
+            boundary_values: Second node-feature column.
+            temperature: Target stored as ``y``.
+            edge_index: Connectivity of shape ``(2, E)``.
+            pos: Node coordinates.
+            edge_attr: Precomputed edge features; derived from ``pos`` and
+                ``edge_index`` when omitted.
+
+        Returns:
+            A ``torch_geometric.data.Data`` graph.
+        """
+        if edge_attr is None:
+            edge_attr = self._edge_features(pos, edge_index)
+        return Data(
+            x=torch.stack([conductivity, boundary_values], dim=1),
+            edge_index=edge_index,
+            pos=pos,
+            edge_attr=edge_attr,
+            y=temperature,
+        )
+
+    def setup(self, stage=None):
+        """Slice the graphs into train/validation/test subsets.
+
+        Args:
+            stage: ``"fit"``, ``"validate"``, ``"test"`` or ``None`` (all).
+        """
+        if self.data is None:
+            self._load_graphs()
+        n = len(self.data)
+        train_end = int(n * self.train_size)
+        val_end = train_end + int(n * self.val_size)
+
+        if stage in ("fit", None):
+            self.train_data = self.data[:train_end]
+        # "validate" needs the validation split without a prior fit.
+        if stage in ("fit", "validate", None):
+            self.val_data = self.data[train_end:val_end]
+        if stage in ("test", None):
+            self.test_data = self.data[val_end:]
+
+    def train_dataloader(self):
+        """Return the shuffled loader over the training graphs."""
+        return DataLoader(
+            self.train_data, batch_size=self.batch_size, shuffle=True
+        )
+
+    def val_dataloader(self):
+        """Return the loader over the validation graphs."""
+        return DataLoader(self.val_data, batch_size=self.batch_size)
+
+    def test_dataloader(self):
+        """Return the loader over the test graphs."""
+        return DataLoader(self.test_data, batch_size=self.batch_size)
diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py
new file mode 100644
index 0000000..c91afa8
--- /dev/null
+++ b/tests/test_datamodule.py
@@ -0,0 +1,23 @@
+import torch
+
+from ThermalSolver.data_module import GraphDataModule
+
+
+def test_graph_data_module():
+    data_module = GraphDataModule(
+        hf_repo="SISSAmathLab/thermal-conduction",
+        split_name="pytest",
+        train_size=0.8,
+        val_size=0.1,
+        test_size=0.1,
+        batch_size=32,
+    )
+    data_module.prepare_data()
+    data_module.setup("fit")
+
+    graphs = data_module.data
+    assert len(graphs) > 0
+    # PyG message passing indexes with int64 connectivity.
+    assert graphs[0].edge_index.dtype == torch.long
+    n_fit = len(data_module.train_data) + len(data_module.val_data)
+    assert n_fit <= len(graphs)