from typing import Optional

import torch
from tqdm import tqdm
from lightning import LightningDataModule
from datasets import load_dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected

from .mesh_data import MeshData


class GraphDataModule(LightningDataModule):
    """Lightning data module that turns Hugging Face snapshots into graphs.

    Two configurations of the same Hugging Face repository are read:
    ``"snapshots"`` (one record per sample, with ``conductivity`` and
    ``temperature`` fields) and ``"geometry"`` (mesh connectivity, node
    coordinates and per-side boundary node ids). Every snapshot is converted
    into a ``MeshData`` graph sharing the geometry of the first geometry
    record.

    The list of graphs is split contiguously, in dataset order, into
    train / validation / test subsets according to the given fractions.
    """

    def __init__(
        self,
        hf_repo: str,
        split_name: str,
        train_size: float = 0.2,
        val_size: float = 0.1,
        test_size: float = 0.1,
        batch_size: int = 32,
    ):
        """Store the configuration; no data is loaded here.

        Args:
            hf_repo: Hugging Face repository identifier passed to
                ``load_dataset``.
            split_name: Name of the dataset split to read from both
                configurations.
            train_size: Fraction of the graphs used for training.
            val_size: Fraction of the graphs used for validation.
            test_size: Fraction of the graphs used for testing.
            batch_size: Batch size used by all dataloaders.
        """
        super().__init__()
        self.hf_repo = hf_repo
        self.split_name = split_name
        # Never populated by this class; kept so the attribute still exists
        # for any external code that reads it.
        self.dataset = None
        self.geometry = None
        # List of graphs; filled by `_load_graphs` (via `prepare_data` or,
        # lazily, via `setup`).
        self.data: Optional[list] = None
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.batch_size = batch_size

    def prepare_data(self):
        """Download the dataset and build the graphs on the calling process.

        Lightning invokes this hook on a single process only, so the state
        assigned here is not visible to other ranks; ``setup`` rebuilds the
        graphs on any process where they are still missing.
        """
        self._load_graphs()

    def _load_graphs(self):
        """Load snapshots and geometry and populate ``self.data``."""
        hf_dataset = load_dataset(self.hf_repo, name="snapshots")[
            self.split_name
        ]
        self.geometry = load_dataset(self.hf_repo, name="geometry")[
            self.split_name
        ]

        # Only the first geometry record is used for every snapshot.
        edge_index = torch.tensor(
            self.geometry["edge_index"][0], dtype=torch.int64
        )
        # Keep only the first two coordinates of every point.
        pos = torch.tensor(self.geometry["points"][0], dtype=torch.float32)[
            :, :2
        ]
        bottom_ids = torch.tensor(
            self.geometry["bottom_boundary_ids"][0], dtype=torch.long
        )
        top_ids = torch.tensor(
            self.geometry["top_boundary_ids"][0], dtype=torch.long
        )
        left_ids = torch.tensor(
            self.geometry["left_boundary_ids"][0], dtype=torch.long
        )
        right_ids = torch.tensor(
            self.geometry["right_boundary_ids"][0], dtype=torch.long
        )

        self.data = [
            self._build_dataset(
                snapshot,
                edge_index.T,  # transposed before being used as edge_index
                pos,
                bottom_ids,
                top_ids,
                left_ids,
                right_ids,
            )
            for snapshot in tqdm(hf_dataset, desc="Building graphs")
        ]

    @staticmethod
    def _side_values(
        temperature: torch.Tensor, ids: torch.Tensor
    ) -> torch.Tensor:
        """Return one value per node in ``ids``: the median temperature there."""
        return torch.ones(len(ids)) * temperature[ids].median()

    def _build_dataset(
        self,
        snapshot: dict,
        edge_index: torch.Tensor,
        pos: torch.Tensor,
        bottom_ids: torch.Tensor,
        top_ids: torch.Tensor,
        left_ids: torch.Tensor,
        right_ids: torch.Tensor,
    ) -> Data:
        """Build the graph for a single snapshot.

        Args:
            snapshot: Record providing ``"conductivity"`` and
                ``"temperature"`` values, indexable by node id.
            edge_index: Edge list in the layout expected by
                ``to_undirected`` (row 0 and row 1 hold the edge endpoints).
            pos: Node coordinates, one row per node.
            bottom_ids: Node ids of the bottom side.
            top_ids: Node ids of the top side. They are only used to remove
                shared nodes from the left/right sides; no boundary value is
                assigned to the top side.
            left_ids: Node ids of the left side.
            right_ids: Node ids of the right side.

        Returns:
            A ``MeshData`` object with a random input ``x``, the conductivity
            ``c``, the graph structure, the boundary node ids
            (``boundary_mask``) with their prescribed values
            (``boundary_values``) and the temperature as target ``y``.
        """
        conductivity = torch.tensor(
            snapshot["conductivity"], dtype=torch.float32
        )
        temperature = torch.tensor(snapshot["temperature"], dtype=torch.float32)

        edge_index = to_undirected(edge_index, num_nodes=pos.size(0))
        # Edge features: coordinate difference between the two endpoints,
        # followed by its Euclidean norm.
        offsets = pos[edge_index[0]] - pos[edge_index[1]]
        edge_attr = torch.cat(
            [offsets, torch.norm(offsets, dim=1).unsqueeze(-1)], dim=1
        )

        # Drop from the lateral sides every node that also belongs to the
        # bottom or top side, so that no node id appears twice below.
        left_ids = left_ids[~torch.isin(left_ids, bottom_ids)]
        right_ids = right_ids[~torch.isin(right_ids, bottom_ids)]
        left_ids = left_ids[~torch.isin(left_ids, top_ids)]
        right_ids = right_ids[~torch.isin(right_ids, top_ids)]

        # Values and ids are concatenated in the same order
        # (bottom, right, left) so they stay aligned.
        boundary_values = torch.cat(
            [
                self._side_values(temperature, bottom_ids),
                self._side_values(temperature, right_ids),
                self._side_values(temperature, left_ids),
            ],
            dim=0,
        )
        # Despite the name, this holds node indices, not a boolean mask.
        boundary_mask = torch.cat([bottom_ids, right_ids, left_ids], dim=0)

        return MeshData(
            x=torch.rand_like(temperature).unsqueeze(-1),
            c=conductivity.unsqueeze(-1),
            edge_index=edge_index,
            pos=pos,
            edge_attr=edge_attr,
            boundary_mask=boundary_mask,
            boundary_values=boundary_values.unsqueeze(-1),
            y=temperature.unsqueeze(-1),
        )

    def setup(self, stage: Optional[str] = None):
        """Create the train / validation / test subsets.

        Args:
            stage: ``"fit"`` or ``"validate"`` create the train and
                validation subsets, ``"test"`` creates the test subset and
                ``None`` creates all of them.
        """
        # `prepare_data` runs on a single process only; build the graphs here
        # on any process that has not built them yet.
        if self.data is None:
            self._load_graphs()

        n = len(self.data)
        train_end = int(n * self.train_size)
        val_end = train_end + int(n * self.val_size)

        if self.train_size + self.val_size + self.test_size > 1.0 - 1e-9:
            # The fractions cover the whole dataset: give the remainder to
            # the test subset so integer truncation does not drop samples.
            test_end = n
        else:
            test_end = min(n, val_end + int(n * self.test_size))

        if stage in ("fit", "validate", None):
            self.train_data = self.data[:train_end]
            self.val_data = self.data[train_end:val_end]
        if stage in ("test", None):
            self.test_data = self.data[val_end:test_end]

    def train_dataloader(self):
        """Return the shuffled dataloader over the training subset."""
        return DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation subset."""
        return DataLoader(
            self.val_data, batch_size=self.batch_size, shuffle=False
        )

    def test_dataloader(self):
        """Return the unshuffled dataloader over the test subset."""
        return DataLoader(
            self.test_data, batch_size=self.batch_size, shuffle=False
        )