import torch
from tqdm import tqdm
from lightning import LightningDataModule
from datasets import load_dataset
import os
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence


class PointDataModule(LightningDataModule):
    """Serve per-sample point data from a Hugging Face repo as padded tensors.

    The ``"snapshots"`` and ``"geometry"`` configs of ``hf_repo`` are loaded
    and their ``split_name`` split is cut, in row order (no shuffling), into
    contiguous train / val / test ranges sized by ``train_size`` /
    ``val_size`` / ``test_size`` (fractions of the row count, truncated to
    ints and clamped to the number of rows).

    Every sample becomes ``(x, y)`` where, for a sample with ``N`` points,
    ``x`` has shape ``(N, 4)`` and ``y`` has shape ``(N, 1)`` (see
    ``_build_dataset``). Within a split, samples are right-padded with ``-1``
    to the longest one and stacked into a ``TensorDataset``.

    Args:
        hf_repo: Hugging Face dataset repository id.
        split_name: Split to read from both configs.
        train_size: Fraction of rows used for training (taken first).
        val_size: Fraction of rows used for validation (taken next).
        test_size: Fraction of rows used for testing (taken after val).
        batch_size: Batch size for every dataloader.
        remove_boundary_edges: Stored on the instance; not read in this file.
        num_workers: Worker processes per dataloader.
        pin_memory: Whether the dataloaders pin host memory.
    """

    def __init__(
        self,
        hf_repo: str,
        split_name: str,
        train_size: float = 0.2,
        val_size: float = 0.1,
        test_size: float = 0.1,
        batch_size: int = 32,
        remove_boundary_edges: bool = True,
        num_workers: int = 8,
        pin_memory: bool = True,
    ):
        super().__init__()
        self.hf_repo = hf_repo
        self.split_name = split_name
        self.dataset_dict = {}
        self.geometry_dict = {}
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.batch_size = batch_size
        # NOTE(review): nothing in this file reads this flag — confirm
        # whether boundary-edge removal was meant to be implemented.
        self.remove_boundary_edges = remove_boundary_edges
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def _split_ranges(self, total_len: int) -> dict:
        """Return contiguous, non-overlapping row ranges for each split.

        Every boundary is clamped to ``total_len`` so oversized fractions
        produce shorter (possibly empty) splits instead of out-of-range
        selections.
        """
        train_end = min(total_len, int(self.train_size * total_len))
        val_end = min(total_len, train_end + int(self.val_size * total_len))
        test_end = min(total_len, val_end + int(self.test_size * total_len))
        return {
            "train": range(0, train_end),
            "val": range(train_end, val_end),
            "test": range(val_end, test_end),
        }

    def _load_splits(self) -> None:
        """Load both configs and populate ``dataset_dict`` / ``geometry_dict``."""
        dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name]
        geometry = load_dataset(self.hf_repo, name="geometry")[self.split_name]
        # Both configs are split with the same row ranges, so row i of a
        # snapshot split pairs with row i of the matching geometry split.
        ranges = self._split_ranges(len(dataset))
        self.dataset_dict = {
            split: dataset.select(rows) for split, rows in ranges.items()
        }
        self.geometry_dict = {
            split: geometry.select(rows) for split, rows in ranges.items()
        }

    def prepare_data(self):
        """Load the data (via ``load_dataset``) and compute the splits.

        Lightning may call this on a single process only, so ``setup``
        reloads the splits in any process where they are still missing.
        """
        self._load_splits()

    def _compute_boundary_mask(
        self, bottom_ids, right_ids, top_ids, left_ids, temperature
    ):
        """Return boundary node indices and a constant value for each.

        Left/right ids that also appear among the bottom or top ids are
        dropped, so corners shared with the bottom side are kept only under
        bottom. The top ids are used only for that filtering; the top side
        itself is not part of the result. Each remaining side gets the
        median of ``temperature`` over its ids.

        Returns:
            ``(boundary_mask, boundary_values)``: a 1-D index tensor ordered
            bottom, right, left, and the matching 1-D value tensor.
        """
        corner_ids = torch.cat([bottom_ids, top_ids])
        left_ids = left_ids[~torch.isin(left_ids, corner_ids)]
        right_ids = right_ids[~torch.isin(right_ids, corner_ids)]

        sides = (bottom_ids, right_ids, left_ids)
        boundary_values = torch.cat(
            [torch.ones(len(ids)) * temperature[ids].median() for ids in sides],
            dim=0,
        )
        boundary_mask = torch.cat(sides, dim=0)
        return boundary_mask, boundary_values

    def _build_dataset(
        self,
        snapshot: dict,
        geometry: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Turn one snapshot/geometry row pair into ``(x, y)`` tensors.

        ``x`` columns are: boundary value (0 off the boundary), conductivity,
        then the first two coordinate columns of ``geometry["points"]``.
        ``y`` is ``temperature`` with a trailing unit dimension.
        """
        conductivity = torch.tensor(
            snapshot["conductivity"], dtype=torch.float32
        )
        temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)
        pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]
        bottom_ids = torch.tensor(
            geometry["bottom_boundary_ids"], dtype=torch.long
        )
        top_ids = torch.tensor(geometry["top_boundary_ids"], dtype=torch.long)
        left_ids = torch.tensor(geometry["left_boundary_ids"], dtype=torch.long)
        right_ids = torch.tensor(
            geometry["right_boundary_ids"], dtype=torch.long
        )

        boundary_mask, boundary_values = self._compute_boundary_mask(
            bottom_ids, right_ids, top_ids, left_ids, temperature
        )
        # Column 0: known boundary-condition value, zero everywhere else.
        bc = torch.zeros_like(temperature, dtype=torch.float32).unsqueeze(-1)
        bc[boundary_mask] = boundary_values.unsqueeze(-1)
        x = torch.cat([bc, conductivity.unsqueeze(-1), pos], dim=-1)
        return x, temperature.unsqueeze(-1)

    def _build_split(self, split: str) -> TensorDataset:
        """Convert every sample of ``split`` and pad them into one dataset."""
        snapshots = self.dataset_dict[split]
        geometries = self.geometry_dict[split]
        # Fresh lists per split so samples never leak between splits.
        xs, ys = [], []
        for snap, geom in tqdm(
            zip(snapshots, geometries),
            desc=f"Building {split} graphs",
            total=len(snapshots),
        ):
            x_i, y_i = self._build_dataset(snap, geom)
            xs.append(x_i)
            ys.append(y_i)
        if not xs:
            # pad_sequence rejects an empty list; return an empty dataset
            # whose trailing dims match _build_dataset (x: 4, y: 1).
            return TensorDataset(torch.empty(0, 0, 4), torch.empty(0, 0, 1))
        # NOTE(review): -1 padding is indistinguishable from a genuine -1
        # value and no per-sample lengths are kept; downstream code must
        # handle padded rows itself.
        return TensorDataset(
            pad_sequence(xs, batch_first=True, padding_value=-1),
            pad_sequence(ys, batch_first=True, padding_value=-1),
        )

    def setup(self, stage: "str | None" = None):
        """Build the padded tensor datasets needed for ``stage``.

        ``"fit"`` builds train and val, ``"validate"`` builds val, ``"test"``
        builds test, and ``None`` builds all three.
        """
        if not self.dataset_dict:
            # prepare_data did not run in this process (e.g. non-zero rank).
            self._load_splits()
        if stage in ("fit", None):
            self.train_dataset = self._build_split("train")
        if stage in ("fit", "validate", None):
            self.val_dataset = self._build_split("val")
        if stage in ("test", None):
            self.test_data = self._build_split("test")
            # Alias consistent with train_dataset / val_dataset naming.
            self.test_dataset = self.test_data

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_data, shuffle=False)