fix
@@ -275,14 +275,19 @@ class PinaDataModule(LightningDataModule):
                 ),
                 module="lightning.pytorch.trainer.connectors.data_connector",
             )
-        return PinaDataLoader(
+        dl = PinaDataLoader(
             dataset,
             batch_size=self.batch_size,
             shuffle=self.shuffle,
             num_workers=self.num_workers,
             common_batch_size=self.common_batch_size,
             separate_conditions=self.separate_conditions,
+            device=self.trainer.strategy.root_device,
         )
+        if self.batch_size is None:
+            # Override the method to transfer the batch to the device
+            self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dl
 
     def val_dataloader(self):
         """
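Note on the hunk above: the loader is kept in `dl` so that, when `batch_size` is None and the whole dataset is served as one pre-built batch on the root device, Lightning's per-batch transfer hook can be swapped out before returning. A minimal sketch of that hook-override pattern follows; `_transfer_batch_to_device_dummy` is named in the diff, but its body here is an assumption (a no-op that returns the batch unchanged):

    class DataModuleSketch:
        # Stand-in for PinaDataModule, reduced to the hook swap only.
        def transfer_batch_to_device(self, batch, device, dataloader_idx):
            # Lightning's default hook copies the batch on every step.
            return batch  # the real hook would move `batch` to `device`

        def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
            # Assumed no-op: the DummyDataloader already moved the data once.
            return batch

    dm = DataModuleSketch()
    # Mirrors the diff: replace the hook when everything is one batch.
    dm.transfer_batch_to_device = dm._transfer_batch_to_device_dummy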
@@ -325,7 +330,7 @@ class PinaDataModule(LightningDataModule):
         :rtype: list[tuple]
         """
 
-        return batch
+        return [(k, v) for k, v in batch.items()]
 
     def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
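For illustration, the new return statement converts a condition-keyed dict into the documented `list[tuple]`; the condition names below are invented:

    batch = {"physics": {"input": 1}, "supervised": {"input": 2}}
    pairs = [(k, v) for k, v in batch.items()]
    # pairs == [("physics", {"input": 1}), ("supervised", {"input": 2})]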
@@ -383,9 +388,15 @@ class PinaDataModule(LightningDataModule):
 
         to_return = {}
         if hasattr(self, "train_dataset") and self.train_dataset is not None:
-            to_return["train"] = self.train_dataset.input
+            to_return["train"] = {
+                cond: data.input for cond, data in self.train_dataset.items()
+            }
         if hasattr(self, "val_dataset") and self.val_dataset is not None:
-            to_return["val"] = self.val_dataset.input
+            to_return["val"] = {
+                cond: data.input for cond, data in self.val_dataset.items()
+            }
         if hasattr(self, "test_dataset") and self.test_dataset is not None:
-            to_return["test"] = self.test_dataset.input
+            to_return["test"] = {
+                cond: data.input for cond, data in self.test_dataset.items()
+            }
         return to_return
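After this hunk each split no longer exposes a single `.input`; it is a dict keyed by condition name. A hypothetical shape of the returned structure, with condition names invented for illustration:

    # Sketch only: `...` stands for the per-condition input tensors.
    inputs = {
        "train": {"physics": ..., "boundary": ...},  # cond -> dataset.input
        "val": {"physics": ..., "boundary": ...},
        "test": {"physics": ..., "boundary": ...},
    }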
@@ -13,7 +13,7 @@ class DummyDataloader:
     DataLoader that returns the entire dataset in a single batch.
     """
 
-    def __init__(self, dataset):
+    def __init__(self, dataset, device=None):
         """
         Prepare a dataloader object that returns the entire dataset in a single
         batch. Depending on the number of GPUs, the dataset is managed
@@ -47,9 +47,14 @@ class DummyDataloader:
                 idx.append(i)
                 i += world_size
         else:
-            idx = list(range(len(dataset)))
+            idx = [i for i in range(len(dataset))]
 
         self.dataset = dataset.getitem_from_list(idx)
+        self.device = device
+        self.dataset = (
+            {k: v.to(self.device) for k, v in self.dataset.items()}
+            if self.device
+            else self.dataset
+        )
 
     def __iter__(self):
         """
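Because `DummyDataloader` materialises the full dataset up front, the new `device` argument lets it pay the host-to-device copy once at construction instead of on every epoch. A self-contained sketch of the same one-shot transfer, assuming the dataset is a flat dict of tensors (a simplification of the real `PinaDataset` output):

    import torch

    def move_once(dataset, device=None):
        # Copy every field to `device` a single time; iteration then
        # yields batches that are already device-resident.
        if device is None:
            return dataset
        return {k: v.to(device) for k, v in dataset.items()}

    ds = {"input": torch.zeros(4, 2), "target": torch.ones(4, 1)}
    ds = move_once(ds, torch.device("cpu"))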
@@ -155,12 +160,14 @@ class PinaDataLoader:
         shuffle=False,
         common_batch_size=True,
         separate_conditions=False,
+        device=None,
     ):
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.shuffle = shuffle
         self.separate_conditions = separate_conditions
+        self.device = device
 
         # Batch size None means we want to load the entire dataset in a single
         # batch
@@ -238,7 +245,7 @@ class PinaDataLoader:
         """
         # If batch size is None, use DummyDataloader
         if batch_size is None or batch_size >= len(dataset):
-            return DummyDataloader(dataset)
+            return DummyDataloader(dataset, device=self.device)
 
         # Determine the appropriate collate function
         if not dataset.automatic_batching:
@@ -6,6 +6,12 @@ from torch_geometric.data import Data
 from ..graph import Graph, LabelBatch
 from ..label_tensor import LabelTensor
 
+STACK_FN_MAP = {
+    "label_tensor": LabelTensor.stack,
+    "tensor": torch.stack,
+    "data": LabelBatch.from_data_list,
+}
+
 
 class PinaDatasetFactory:
     """
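Storing string identifiers and resolving them through a module-level `STACK_FN_MAP`, instead of keeping bound callables on the instance, is a common fix for pickling problems with spawn-based dataloader workers (`num_workers > 0`); the commit does not state its motivation, so that reading is an inference. Resolution at collate time looks like this (sample data invented):

    import torch

    samples = [torch.zeros(3), torch.ones(3)]
    fn = STACK_FN_MAP["tensor"]   # resolves to torch.stack
    batch = fn(samples)           # tensor of shape (2, 3)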
@@ -65,18 +71,18 @@ class PinaDataset(Dataset):
         self.automatic_batching = (
             automatic_batching if automatic_batching is not None else True
         )
-        self.stack_fn = {}
+        self._stack_fn = {}
         self.is_graph_dataset = False
         # Determine stacking functions for each data type (used in collate_fn)
         for k, v in data_dict.items():
             if isinstance(v, LabelTensor):
-                self.stack_fn[k] = LabelTensor.stack
+                self._stack_fn[k] = "label_tensor"
             elif isinstance(v, torch.Tensor):
-                self.stack_fn[k] = torch.stack
+                self._stack_fn[k] = "tensor"
             elif isinstance(v, list) and all(
                 isinstance(item, (Data, Graph)) for item in v
             ):
-                self.stack_fn[k] = LabelBatch.from_data_list
+                self._stack_fn[k] = "data"
                 self.is_graph_dataset = True
             else:
                 raise ValueError(
@@ -84,6 +90,11 @@ class PinaDataset(Dataset):
             )
 
     def __len__(self):
+        """
+        Return the length of the dataset.
+        :return: The length of the dataset.
+        :rtype: int
+        """
         return len(next(iter(self.data.values())))
 
     def __getitem__(self, idx):
@@ -113,10 +124,9 @@ class PinaDataset(Dataset):
 
         to_return = {}
         for field_name, data in self.data.items():
-            if self.stack_fn[field_name] is LabelBatch.from_data_list:
-                to_return[field_name] = self.stack_fn[field_name](
-                    [data[i] for i in idx_list]
-                )
+            if self._stack_fn[field_name] == "data":
+                fn = STACK_FN_MAP[self._stack_fn[field_name]]
+                to_return[field_name] = fn([data[i] for i in idx_list])
             else:
                 to_return[field_name] = data[idx_list]
         return to_return
@@ -148,3 +158,14 @@ class PinaDataset(Dataset):
         :rtype: torch.Tensor | LabelTensor | Data | Graph
         """
         return self.data["input"]
+
+    @property
+    def stack_fn(self):
+        """
+        Get the mapping of stacking functions for each data type in the dataset.
+
+        :return: A dictionary mapping field names to their respective
+            stacking functions.
+        :rtype: dict
+        """
+        return {k: STACK_FN_MAP[v] for k, v in self._stack_fn.items()}
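The new read-only `stack_fn` property preserves the previous public behaviour: callers still receive callables, while the instance itself now stores only strings. A hedged usage sketch, where `dataset` stands for an existing `PinaDataset` and the key `"input"` is illustrative:

    fn = dataset.stack_fn["input"]   # e.g. LabelTensor.stack
    # Old call sites such as `fn(items)` keep working unchanged, because
    # the property maps identifiers back through STACK_FN_MAP.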