FilippoOlivo
2025-11-13 17:03:18 +01:00
parent 51a0399111
commit 0ee63686dd
3 changed files with 56 additions and 17 deletions

View File

@@ -275,14 +275,19 @@ class PinaDataModule(LightningDataModule):
             ),
             module="lightning.pytorch.trainer.connectors.data_connector",
         )
-        return PinaDataLoader(
+        dl = PinaDataLoader(
             dataset,
             batch_size=self.batch_size,
             shuffle=self.shuffle,
             num_workers=self.num_workers,
             common_batch_size=self.common_batch_size,
             separate_conditions=self.separate_conditions,
+            device=self.trainer.strategy.root_device,
         )
+        if self.batch_size is None:
+            # Override the method to transfer the batch to the device
+            self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dl

     def val_dataloader(self):
         """
@@ -325,7 +330,7 @@ class PinaDataModule(LightningDataModule):
         :rtype: list[tuple]
         """
-        return batch
+        return [(k, v) for k, v in batch.items()]

     def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
@@ -383,9 +388,15 @@ class PinaDataModule(LightningDataModule):
         to_return = {}
         if hasattr(self, "train_dataset") and self.train_dataset is not None:
-            to_return["train"] = self.train_dataset.input
+            to_return["train"] = {
+                cond: data.input for cond, data in self.train_dataset.items()
+            }
         if hasattr(self, "val_dataset") and self.val_dataset is not None:
-            to_return["val"] = self.val_dataset.input
+            to_return["val"] = {
+                cond: data.input for cond, data in self.val_dataset.items()
+            }
         if hasattr(self, "test_dataset") and self.test_dataset is not None:
-            to_return["test"] = self.test_dataset.input
+            to_return["test"] = {
+                cond: data.input for cond, data in self.test_dataset.items()
+            }
         return to_return
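
Note on the data-module changes above: when batch_size is None, the whole dataset is served as a single pre-built batch that the DummyDataloader already places on self.trainer.strategy.root_device, so transfer_batch_to_device is swapped for a dummy, presumably to avoid moving the data a second time; the per-split inputs are now returned as nested condition-to-input dictionaries. A minimal sketch of the no-op-transfer idea, using hypothetical stand-in names (FullBatchLoader, transfer_batch_to_device_dummy) rather than PINA's actual classes:

# A minimal sketch, not PINA's implementation: a full-batch loader that moves
# the data to the device up front, paired with a no-op transfer hook.
import torch


class FullBatchLoader:
    """Hypothetical stand-in for DummyDataloader: one batch, pre-moved to `device`."""

    def __init__(self, data, device=None):
        # Move every field once, at construction time, if a device is given.
        self.data = {k: v.to(device) for k, v in data.items()} if device else data

    def __iter__(self):
        yield self.data


def transfer_batch_to_device_dummy(batch, device, dataloader_idx):
    # Assumed shape of the dummy hook: the batch already lives on the right
    # device, so it is returned unchanged.
    return batch


device = "cuda" if torch.cuda.is_available() else "cpu"
loader = FullBatchLoader({"input": torch.rand(8, 2)}, device=device)
batch = next(iter(loader))
assert transfer_batch_to_device_dummy(batch, device, 0)["input"].device.type == device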

View File

@@ -13,7 +13,7 @@ class DummyDataloader:
     DataLoader that returns the entire dataset in a single batch.
     """

-    def __init__(self, dataset):
+    def __init__(self, dataset, device=None):
         """
         Prepare a dataloader object that returns the entire dataset in a single
         batch. Depending on the number of GPUs, the dataset is managed
@@ -47,9 +47,14 @@ class DummyDataloader:
                 idx.append(i)
                 i += world_size
         else:
-            idx = list(range(len(dataset)))
+            idx = [i for i in range(len(dataset))]
         self.dataset = dataset.getitem_from_list(idx)
+        self.device = device
+        self.dataset = (
+            {k: v.to(self.device) for k, v in self.dataset.items()}
+            if self.device
+            else self.dataset
+        )

     def __iter__(self):
         """
@@ -155,12 +160,14 @@ class PinaDataLoader:
         shuffle=False,
         common_batch_size=True,
         separate_conditions=False,
+        device=None,
     ):
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.shuffle = shuffle
         self.separate_conditions = separate_conditions
+        self.device = device

         # Batch size None means we want to load the entire dataset in a single
         # batch
@@ -238,7 +245,7 @@ class PinaDataLoader:
""" """
# If batch size is None, use DummyDataloader # If batch size is None, use DummyDataloader
if batch_size is None or batch_size >= len(dataset): if batch_size is None or batch_size >= len(dataset):
return DummyDataloader(dataset) return DummyDataloader(dataset, device=self.device)
# Determine the appropriate collate function # Determine the appropriate collate function
if not dataset.automatic_batching: if not dataset.automatic_batching:
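
The device argument threaded from PinaDataLoader into DummyDataloader is optional and defaults to None, in which case the batch is left where it is, so existing callers keep the old behavior. A small sketch of that guard in isolation (maybe_to_device is an illustrative helper, not part of PINA):

# A minimal sketch of the optional device move added to DummyDataloader.
import torch


def maybe_to_device(batch, device=None):
    # Only move when a device is passed; device=None leaves the batch
    # untouched, which matches the pre-commit behavior.
    return {k: v.to(device) for k, v in batch.items()} if device else batch


batch = {"input": torch.rand(4, 3), "target": torch.rand(4, 1)}
unchanged = maybe_to_device(batch)             # no device given -> tensors untouched
on_cpu = maybe_to_device(batch, device="cpu")  # explicit device -> tensors placed there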

View File

@@ -6,6 +6,12 @@ from torch_geometric.data import Data
 from ..graph import Graph, LabelBatch
 from ..label_tensor import LabelTensor

+STACK_FN_MAP = {
+    "label_tensor": LabelTensor.stack,
+    "tensor": torch.stack,
+    "data": LabelBatch.from_data_list,
+}
+

 class PinaDatasetFactory:
     """
@@ -65,18 +71,18 @@ class PinaDataset(Dataset):
         self.automatic_batching = (
             automatic_batching if automatic_batching is not None else True
         )
-        self.stack_fn = {}
+        self._stack_fn = {}
         self.is_graph_dataset = False
         # Determine stacking functions for each data type (used in collate_fn)
         for k, v in data_dict.items():
             if isinstance(v, LabelTensor):
-                self.stack_fn[k] = LabelTensor.stack
+                self._stack_fn[k] = "label_tensor"
             elif isinstance(v, torch.Tensor):
-                self.stack_fn[k] = torch.stack
+                self._stack_fn[k] = "tensor"
             elif isinstance(v, list) and all(
                 isinstance(item, (Data, Graph)) for item in v
             ):
-                self.stack_fn[k] = LabelBatch.from_data_list
+                self._stack_fn[k] = "data"
                 self.is_graph_dataset = True
             else:
                 raise ValueError(
@@ -84,6 +90,11 @@ class PinaDataset(Dataset):
             )

     def __len__(self):
+        """
+        Return the length of the dataset.
+
+        :return: The length of the dataset.
+        :rtype: int
+        """
         return len(next(iter(self.data.values())))

     def __getitem__(self, idx):
@@ -113,10 +124,9 @@ class PinaDataset(Dataset):
         to_return = {}
         for field_name, data in self.data.items():
-            if self.stack_fn[field_name] is LabelBatch.from_data_list:
-                to_return[field_name] = self.stack_fn[field_name](
-                    [data[i] for i in idx_list]
-                )
+            if self._stack_fn[field_name] == "data":
+                fn = STACK_FN_MAP[self._stack_fn[field_name]]
+                to_return[field_name] = fn([data[i] for i in idx_list])
             else:
                 to_return[field_name] = data[idx_list]
         return to_return
@@ -148,3 +158,14 @@ class PinaDataset(Dataset):
         :rtype: torch.Tensor | LabelTensor | Data | Graph
         """
         return self.data["input"]
+
+    @property
+    def stack_fn(self):
+        """
+        Get the mapping of stacking functions for each data type in the
+        dataset.
+
+        :return: A dictionary mapping condition names to their respective
+            stacking functions.
+        :rtype: dict
+        """
+        return {k: STACK_FN_MAP[v] for k, v in self._stack_fn.items()}
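
The last file replaces stored function references with string identifiers ("label_tensor", "tensor", "data") that the module-level STACK_FN_MAP resolves back to callables through the new stack_fn property; a plausible motivation, not stated in the commit, is keeping the dataset's per-field state as plain strings rather than bound callables. A toy sketch of the same indirection using only torch.stack (TinyDataset is hypothetical; LabelTensor.stack and LabelBatch.from_data_list are the other real PINA entries):

# A toy version of the string-keyed stacking indirection; only the "tensor"
# entry is reproduced here.
import torch

STACK_FN_MAP = {"tensor": torch.stack}


class TinyDataset:
    def __init__(self, data):
        # Store a string identifier per field instead of a function reference.
        self._stack_fn = {k: "tensor" for k in data}
        self.data = data

    @property
    def stack_fn(self):
        # Resolve the identifiers back to callables only when requested.
        return {k: STACK_FN_MAP[v] for k, v in self._stack_fn.items()}


ds = TinyDataset({"input": [torch.rand(3) for _ in range(4)]})
stacked = ds.stack_fn["input"](ds.data["input"])  # torch.stack over 4 samples
print(stacked.shape)                              # torch.Size([4, 3])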