From 0ee63686dde860e5e6d58151b1ba4b29c2192c12 Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Thu, 13 Nov 2025 17:03:18 +0100
Subject: [PATCH] fix device handling for full-batch dataloaders and dataset
 stack functions

---
 pina/data/data_module.py | 21 ++++++++++++++++-----
 pina/data/dataloader.py  | 15 +++++++++++----
 pina/data/dataset.py     | 37 +++++++++++++++++++++++++++++--------
 3 files changed, 56 insertions(+), 17 deletions(-)

diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index f1910f8..9a0cf0a 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -275,14 +275,19 @@ class PinaDataModule(LightningDataModule):
             ),
             module="lightning.pytorch.trainer.connectors.data_connector",
         )
-        return PinaDataLoader(
+        dl = PinaDataLoader(
             dataset,
             batch_size=self.batch_size,
             shuffle=self.shuffle,
             num_workers=self.num_workers,
             common_batch_size=self.common_batch_size,
             separate_conditions=self.separate_conditions,
+            device=self.trainer.strategy.root_device,
         )
+        if self.batch_size is None:
+            # Override the method to transfer the batch to the device
+            self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dl
 
     def val_dataloader(self):
         """
@@ -325,7 +330,7 @@
 
         :rtype: list[tuple]
         """
-        return batch
+        return [(k, v) for k, v in batch.items()]
 
     def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
@@ -383,9 +388,15 @@
         to_return = {}
 
         if hasattr(self, "train_dataset") and self.train_dataset is not None:
-            to_return["train"] = self.train_dataset.input
+            to_return["train"] = {
+                cond: data.input for cond, data in self.train_dataset.items()
+            }
         if hasattr(self, "val_dataset") and self.val_dataset is not None:
-            to_return["val"] = self.val_dataset.input
+            to_return["val"] = {
+                cond: data.input for cond, data in self.val_dataset.items()
+            }
         if hasattr(self, "test_dataset") and self.test_dataset is not None:
-            to_return["test"] = self.test_dataset.input
+            to_return["test"] = {
+                cond: data.input for cond, data in self.test_dataset.items()
+            }
         return to_return
diff --git a/pina/data/dataloader.py b/pina/data/dataloader.py
index 29b3673..6267868 100644
--- a/pina/data/dataloader.py
+++ b/pina/data/dataloader.py
@@ -13,7 +13,7 @@ class DummyDataloader:
     DataLoader that returns the entire dataset in a single batch.
     """
 
-    def __init__(self, dataset):
+    def __init__(self, dataset, device=None):
         """
         Prepare a dataloader object that returns the entire dataset in a
         single batch. Depending on the number of GPUs, the dataset is managed
@@ -47,9 +47,14 @@
                 idx.append(i)
                 i += world_size
         else:
-            idx = list(range(len(dataset)))
-
+            idx = [i for i in range(len(dataset))]
         self.dataset = dataset.getitem_from_list(idx)
+        self.device = device
+        self.dataset = (
+            {k: v.to(self.device) for k, v in self.dataset.items()}
+            if self.device
+            else self.dataset
+        )
 
     def __iter__(self):
         """
@@ -155,12 +160,14 @@
         shuffle=False,
         common_batch_size=True,
         separate_conditions=False,
+        device=None,
     ):
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.shuffle = shuffle
         self.separate_conditions = separate_conditions
+        self.device = device
 
         # Batch size None means we want to load the entire dataset in a single
         # batch
@@ -238,7 +245,7 @@
         """
         # If batch size is None, use DummyDataloader
         if batch_size is None or batch_size >= len(dataset):
-            return DummyDataloader(dataset)
+            return DummyDataloader(dataset, device=self.device)
 
         # Determine the appropriate collate function
         if not dataset.automatic_batching:
diff --git a/pina/data/dataset.py b/pina/data/dataset.py
index 725f1f5..9448511 100644
--- a/pina/data/dataset.py
+++ b/pina/data/dataset.py
@@ -6,6 +6,12 @@ from torch_geometric.data import Data
 from ..graph import Graph, LabelBatch
 from ..label_tensor import LabelTensor
 
+STACK_FN_MAP = {
+    "label_tensor": LabelTensor.stack,
+    "tensor": torch.stack,
+    "data": LabelBatch.from_data_list,
+}
+
 
 class PinaDatasetFactory:
     """
@@ -65,18 +71,18 @@ class PinaDataset(Dataset):
         self.automatic_batching = (
             automatic_batching if automatic_batching is not None else True
         )
-        self.stack_fn = {}
+        self._stack_fn = {}
         self.is_graph_dataset = False
         # Determine stacking functions for each data type (used in collate_fn)
         for k, v in data_dict.items():
             if isinstance(v, LabelTensor):
-                self.stack_fn[k] = LabelTensor.stack
+                self._stack_fn[k] = "label_tensor"
             elif isinstance(v, torch.Tensor):
-                self.stack_fn[k] = torch.stack
+                self._stack_fn[k] = "tensor"
             elif isinstance(v, list) and all(
                 isinstance(item, (Data, Graph)) for item in v
             ):
-                self.stack_fn[k] = LabelBatch.from_data_list
+                self._stack_fn[k] = "data"
                 self.is_graph_dataset = True
             else:
                 raise ValueError(
@@ -84,6 +90,11 @@
                 )
 
     def __len__(self):
+        """
+        Return the length of the dataset.
+        :return: The length of the dataset.
+        :rtype: int
+        """
         return len(next(iter(self.data.values())))
 
     def __getitem__(self, idx):
@@ -113,10 +124,9 @@
 
         to_return = {}
         for field_name, data in self.data.items():
-            if self.stack_fn[field_name] is LabelBatch.from_data_list:
-                to_return[field_name] = self.stack_fn[field_name](
-                    [data[i] for i in idx_list]
-                )
+            if self._stack_fn[field_name] == "data":
+                fn = STACK_FN_MAP[self._stack_fn[field_name]]
+                to_return[field_name] = fn([data[i] for i in idx_list])
             else:
                 to_return[field_name] = data[idx_list]
         return to_return
@@ -148,3 +158,14 @@
         :rtype: torch.Tensor | LabelTensor | Data | Graph
         """
         return self.data["input"]
+
+    @property
+    def stack_fn(self):
+        """
+        Get the mapping of stacking functions for each data type in the dataset.
+
+        :return: A dictionary mapping condition names to their respective
+            stacking function identifiers.
+        :rtype: dict
+        """
+        return {k: STACK_FN_MAP[v] for k, v in self._stack_fn.items()}
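---

Reviewer notes, with small self-contained sketches. All names in the
sketches are illustrative only; none of this is PINA's public API.

1. Full-batch device placement. With batch_size=None, the dataloader now
moves the entire dataset to the strategy's root device once, at
construction time, and the data module swaps Lightning's per-batch
transfer hook for a dummy so already-resident tensors are not moved again
on every iteration. A minimal sketch of that pattern, assuming a plain
LightningDataModule and a TensorDataset:

    import torch
    from torch.utils.data import TensorDataset
    from lightning.pytorch import LightningDataModule

    class FullBatchDataModule(LightningDataModule):
        def __init__(self, inputs, targets, device):
            super().__init__()
            self.dataset = TensorDataset(inputs, targets)
            self.device = device

        def train_dataloader(self):
            # Move the single batch to the device once, up front.
            batch = tuple(t.to(self.device) for t in self.dataset.tensors)
            # Lightning would move each batch again through this hook;
            # replace it with an identity so resident tensors are reused.
            self.transfer_batch_to_device = (
                lambda batch, device, dataloader_idx: batch
            )
            # Any iterable works as a dataloader: one batch per epoch.
            return [batch]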
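2. String-keyed stack functions. The dataset now records a string
identifier per field ("label_tensor", "tensor", "data") and resolves it
through the module-level STACK_FN_MAP; the public stack_fn property still
hands back real callables. One plausible motivation (my inference, the
commit does not state it) is that plain strings keep dataset instances
cheap and safe to pickle when a DataLoader forks or spawns workers,
whereas callables stored on the instance can complicate serialization.
The registry pattern in isolation:

    import torch

    STACK_FN_MAP = {"tensor": torch.stack}

    class ToyDataset:
        def __init__(self, tensors):
            self.data = tensors
            self._stack_fn = "tensor"  # plain identifier, not a callable

        def collate(self, items):
            # Resolve the callable only at use time.
            return STACK_FN_MAP[self._stack_fn](items)

    ds = ToyDataset([torch.zeros(3), torch.ones(3)])
    print(ds.collate(ds.data).shape)  # torch.Size([2, 3])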
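3. Rank-strided index splits. DummyDataloader deals indices out
round-robin across ranks (i, i + world_size, i + 2 * world_size, ...),
so each GPU holds a disjoint slice of the single full-dataset batch. The
same striding as a standalone helper (hypothetical name):

    def strided_indices(rank, world_size, n):
        # Indices owned by `rank` when n samples are dealt round-robin.
        return list(range(rank, n, world_size))

    # With 2 ranks and 7 samples:
    # strided_indices(0, 2, 7) -> [0, 2, 4, 6]
    # strided_indices(1, 2, 7) -> [1, 3, 5]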