from typing import List, Optional

import torch
from tqdm import tqdm
from lightning import LightningDataModule
from datasets import load_dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected

from .mesh_data import MeshData


class GraphDataModule(LightningDataModule):
    """Lightning data module that turns mesh snapshots into graph samples.

    Two configurations of the same Hugging Face repository are loaded:
    ``"snapshots"`` (per-sample fields) and ``"geometry"`` (mesh points,
    edges, conductivity and constraints). Records of the two configurations
    are paired by position, split contiguously into train/val/test, and each
    pair is converted into a list of ``MeshData`` graphs.

    Args:
        hf_repo: Repository identifier passed to ``datasets.load_dataset``.
        split_name: Split of the repository to read.
        n_elements: If given, only the first ``n_elements`` records are used
            (clamped to the number of available records).
        train_size: Fraction of the records used for training.
        val_size: Fraction of the records used for validation.
        test_size: Fraction of the records used for testing.
        batch_size: Batch size of the training loader.
        remove_boundary_edges: If true, edges whose target node is listed in
            the constraints mask are dropped.
        build_radial_graph: Not implemented; ``True`` makes graph building
            raise ``NotImplementedError``.
        radius: Stored on the instance; not used by the code in this module.
        unrolling_steps: Number of consecutive future steps stored as the
            target of every training/validation sample.
        num_workers: Worker processes used by every data loader.
        val_batch_size: Batch size of the validation loader.
    """

    def __init__(
        self,
        hf_repo: str,
        split_name: str,
        n_elements: Optional[int] = None,
        train_size: float = 0.2,
        val_size: float = 0.1,
        test_size: float = 0.1,
        batch_size: int = 32,
        remove_boundary_edges: bool = False,
        build_radial_graph: bool = False,
        radius: Optional[float] = None,
        unrolling_steps: int = 1,
        num_workers: int = 8,
        val_batch_size: int = 128,
    ):
        super().__init__()
        self.hf_repo = hf_repo
        self.split_name = split_name
        self.n_elements = n_elements
        self.dataset_dict = {}
        self.geometry_dict = {}
        # Kept for backward compatibility; nothing in this module fills them.
        self.train_dataset, self.val_dataset, self.test_dataset = (
            None,
            None,
            None,
        )
        # Filled by ``setup``; ``None`` means "not built yet".
        self.train_data = None
        self.val_data = None
        self.test_data = None
        self.unrolling_steps = unrolling_steps
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.remove_boundary_edges = remove_boundary_edges
        self.build_radial_graph = build_radial_graph
        self.radius = radius

    def prepare_data(self):
        """Load the repository and compute the train/val/test splits.

        Lightning invokes this hook on a single process only, so ``setup``
        repeats the loading on any process whose split dictionaries are
        still empty.
        """
        self._load_splits()

    def _load_splits(self):
        """Populate ``dataset_dict`` and ``geometry_dict`` with the splits."""
        dataset = load_dataset(self.hf_repo, name="snapshots")[self.split_name]
        geometry = load_dataset(self.hf_repo, name="geometry")[self.split_name]

        if self.n_elements is not None:
            # Clamp so that asking for more records than exist does not raise.
            n_keep = min(self.n_elements, len(dataset), len(geometry))
            dataset = dataset.select(range(n_keep))
            geometry = geometry.select(range(n_keep))

        total_len = len(dataset)
        train_len = int(self.train_size * total_len)
        valid_len = int(self.val_size * total_len)
        test_len = int(self.test_size * total_len)

        # Contiguous, non-overlapping index ranges, clamped to the data size.
        train_end = min(train_len, total_len)
        val_end = min(train_end + valid_len, total_len)
        test_end = min(val_end + test_len, total_len)
        bounds = {
            "train": range(0, train_end),
            "val": range(train_end, val_end),
            "test": range(val_end, test_end),
        }

        self.dataset_dict = {
            name: dataset.select(idx) for name, idx in bounds.items()
        }
        self.geometry_dict = {
            name: geometry.select(idx) for name, idx in bounds.items()
        }

    def _build_dataset(
        self,
        snapshot: dict,
        geometry: dict,
        test: bool = False,
    ) -> List[MeshData]:
        """Convert one snapshot/geometry pair into a list of graph samples.

        Args:
            snapshot: Record providing ``"unsteady"`` and, when ``test`` is
                true, ``"steady"``.
            geometry: Record providing ``"conductivity"``, ``"points"``,
                ``"edge_index"``, ``"constraints_mask"`` and
                ``"constraints_values"``.
            test: If true, a single sample is produced whose input is the
                first ``"unsteady"`` entry and whose target is ``"steady"``.

        Returns:
            A list of ``MeshData`` objects: one element in test mode,
            otherwise one per starting index of an unrolling window.

        Raises:
            NotImplementedError: If ``build_radial_graph`` is enabled.
        """
        conductivity = torch.tensor(
            geometry["conductivity"], dtype=torch.float32
        )
        unsteady = torch.tensor(snapshot["unsteady"], dtype=torch.float32)
        if test:
            steady = torch.tensor(snapshot["steady"], dtype=torch.float32)
            # Row 0: initial state, row 1: steady-state target.
            temperatures = torch.stack([unsteady[0, ...], steady], dim=0)
        else:
            # Presumably laid out as (time, nodes) - TODO confirm.
            temperatures = unsteady

        # Only the first two coordinates of every point are kept.
        pos = torch.tensor(geometry["points"], dtype=torch.float32)[:, :2]

        if self.build_radial_graph:
            raise NotImplementedError(
                "Radial graph building not implemented yet."
            )
        edge_index = torch.tensor(geometry["edge_index"], dtype=torch.int64).T
        edge_index = to_undirected(edge_index, num_nodes=pos.size(0))

        boundary_mask = torch.tensor(
            geometry["constraints_mask"], dtype=torch.int64
        )
        boundary_values = torch.tensor(
            geometry["constraints_values"], dtype=torch.float32
        )

        # Edge feature: Euclidean distance between the two endpoints.
        edge_attr = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1)

        if self.remove_boundary_edges:
            # NOTE(review): the values of ``constraints_mask`` are used as
            # node indices here; if it is a 0/1 mask instead, only edges into
            # nodes 0 and 1 would be dropped - confirm with the data schema.
            boundary_idx = torch.unique(boundary_mask)
            keep = ~torch.isin(edge_index[1], boundary_idx)
            edge_index = edge_index[:, keep]
            edge_attr = edge_attr[keep]

        # Fields that are identical for every sample built from this pair.
        shared = dict(
            c=conductivity.unsqueeze(-1),
            edge_index=edge_index,
            pos=pos,
            edge_attr=edge_attr,
            boundary_mask=boundary_mask,
            boundary_values=boundary_values,
        )

        if test:
            return [
                MeshData(
                    x=temperatures[0, :].unsqueeze(-1),
                    y=temperatures[1:2, :].unsqueeze(-1).permute(1, 0, 2),
                    **shared,
                )
            ]

        # At least one sample is produced even when the sequence is not
        # longer than the unrolling window (the target is then shorter).
        n_data = max(temperatures.size(0) - self.unrolling_steps, 1)
        samples = []
        for start in range(n_data):
            target = temperatures[
                start + 1 : start + 1 + self.unrolling_steps, :
            ]
            samples.append(
                MeshData(
                    x=temperatures[start, :].unsqueeze(-1),
                    # Move the step axis behind the leading per-row axis.
                    y=target.unsqueeze(-1).permute(1, 0, 2),
                    **shared,
                )
            )
        return samples

    def _build_split(
        self, split: str, test: bool = False
    ) -> List[List[MeshData]]:
        """Build the graph samples of every record in the named split."""
        snapshots = self.dataset_dict[split]
        geometries = self.geometry_dict[split]
        return [
            self._build_dataset(snap, geom, test=test)
            for snap, geom in tqdm(
                zip(snapshots, geometries),
                desc=f"Building {split} graphs",
                total=len(snapshots),
            )
        ]

    def setup(self, stage: Optional[str] = None):
        """Build the graph lists required by ``stage``.

        Args:
            stage: ``"fit"`` builds train and val, ``"validate"`` builds val,
                ``"test"`` builds test, and ``None`` builds all three.
        """
        # ``prepare_data`` may not have run in this process.
        if not self.dataset_dict or not self.geometry_dict:
            self._load_splits()

        if stage in ("fit", None):
            self.train_data = self._build_split("train")
        if stage in ("fit", "validate", None):
            self.val_data = self._build_split("val")
        if stage in ("test", None):
            self.test_data = self._build_split("test", test=True)

    def _make_loader(
        self,
        graphs: Optional[List[List[MeshData]]],
        name: str,
        batch_size: int,
        shuffle: bool,
    ) -> DataLoader:
        """Flatten per-record sample lists and wrap them in a ``DataLoader``.

        Raises:
            RuntimeError: If the requested data has not been built yet.
        """
        if graphs is None:
            raise RuntimeError(
                f"{name} data has not been built; call setup() before "
                f"requesting the {name} dataloader."
            )
        samples = [sample for group in graphs for sample in group]
        return DataLoader(
            samples,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return a shuffled loader over all training samples."""
        print(
            f"\nLoading training data, using {self.unrolling_steps} unrolling steps..."
        )
        return self._make_loader(
            self.train_data, "train", self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Return an ordered loader over all validation samples."""
        print(
            f"\nLoading validation data, using {self.unrolling_steps} unrolling steps..."
        )
        return self._make_loader(
            self.val_data, "val", self.val_batch_size, shuffle=False
        )

    def test_dataloader(self):
        """Return an ordered loader yielding one test sample per batch."""
        return self._make_loader(self.test_data, "test", 1, shuffle=False)