refactor dataset, dataloader and datamodule

This commit is contained in:
FilippoOlivo
2025-11-12 14:32:56 +01:00
parent f07e59b69b
commit 99e2f07cf7
4 changed files with 375 additions and 460 deletions

pina/data/__init__.py

@@ -4,4 +4,3 @@ __all__ = ["PinaDataModule", "PinaDataset"]
from .data_module import PinaDataModule
from .dataset import PinaDataset

pina/data/data_module.py

@@ -11,203 +11,8 @@ from torch_geometric.data import Data
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from ..label_tensor import LabelTensor
-from .dataset import PinaDatasetFactory, PinaTensorDataset
+from .dataset import PinaDatasetFactory
+from .dataloader import PinaDataLoader
class DummyDataloader:
def __init__(self, dataset):
"""
Prepare a dataloader object that returns the entire dataset in a single
batch. Depending on the number of GPUs, the dataset is managed
as follows:
- **Distributed Environment** (multiple GPUs): Divides dataset across
processes using the rank and world size. Fetches only portion of
data corresponding to the current process.
- **Non-Distributed Environment** (single GPU): Fetches the entire
dataset.
:param PinaDataset dataset: The dataset object to be processed.
.. note::
This dataloader is used when the batch size is ``None``.
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if len(dataset) < world_size:
raise RuntimeError(
"Dimension of the dataset smaller than world size."
" Increase the size of the partition or use a single GPU"
)
idx, i = [], rank
while i < len(dataset):
idx.append(i)
i += world_size
self.dataset = dataset.fetch_from_idx_list(idx)
else:
self.dataset = dataset.get_all_data()
def __iter__(self):
return self
def __len__(self):
return 1
def __next__(self):
return self.dataset
class Collator:
"""
This callable class is used to collate the data points fetched from the
dataset. The collation is performed based on the type of dataset used and
on the batching strategy.
"""
def __init__(
self, max_conditions_lengths, automatic_batching, dataset=None
):
"""
Initialize the object, setting the collate function based on whether
automatic batching is enabled or not.
:param dict max_conditions_lengths: ``dict`` containing the maximum
number of data points to consider in a single batch for
each condition.
:param bool automatic_batching: Whether automatic PyTorch batching is
enabled or not. For more information, see the
:class:`~pina.data.data_module.PinaDataModule` class.
:param PinaDataset dataset: The dataset where the data is stored.
"""
self.max_conditions_lengths = max_conditions_lengths
# Set the collate function based on the batching strategy
# collate_pina_dataloader is used when automatic batching is disabled
# collate_torch_dataloader is used when automatic batching is enabled
self.callable_function = (
self._collate_torch_dataloader
if automatic_batching
else (self._collate_pina_dataloader)
)
self.dataset = dataset
# Set the function which performs the actual collation
if isinstance(self.dataset, PinaTensorDataset):
# If the dataset is a PinaTensorDataset, use this collate function
self._collate = self._collate_tensor_dataset
else:
# If the dataset is a PinaDataset, use this collate function
self._collate = self._collate_graph_dataset
def _collate_pina_dataloader(self, batch):
"""
Function used to create a batch when automatic batching is disabled.
:param list[int] batch: List of integers representing the indices of
the data points to be fetched.
:return: Dictionary containing the data points fetched from the dataset.
:rtype: dict
"""
# Call the fetch_from_idx_list method of the dataset
return self.dataset.fetch_from_idx_list(batch)
def _collate_torch_dataloader(self, batch):
"""
Function used to collate the batch
:param list[dict] batch: List of retrieved data.
:return: Dictionary containing the data points fetched from the dataset,
collated.
:rtype: dict
"""
batch_dict = {}
if isinstance(batch, dict):
return batch
conditions_names = batch[0].keys()
# Condition names
for condition_name in conditions_names:
single_cond_dict = {}
condition_args = batch[0][condition_name].keys()
for arg in condition_args:
data_list = [
batch[idx][condition_name][arg]
for idx in range(
min(
len(batch),
self.max_conditions_lengths[condition_name],
)
)
]
single_cond_dict[arg] = self._collate(data_list)
batch_dict[condition_name] = single_cond_dict
return batch_dict
@staticmethod
def _collate_tensor_dataset(data_list):
"""
Function used to collate the data when the dataset is a
:class:`~pina.data.dataset.PinaTensorDataset`.
:param data_list: Elements to be collated.
:type data_list: list[torch.Tensor] | list[LabelTensor]
:return: Batch of data.
:rtype: dict
:raises RuntimeError: If the data is not a :class:`torch.Tensor` or a
:class:`~pina.label_tensor.LabelTensor`.
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.stack(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.stack(data_list)
raise RuntimeError("Data must be Tensors or LabelTensor ")
def _collate_graph_dataset(self, data_list):
"""
Function used to collate data when the dataset is a
:class:`~pina.data.dataset.PinaGraphDataset`.
        :param data_list: Elements to be collated.
:type data_list: list[Data] | list[Graph]
:return: Batch of data.
:rtype: dict
:raises RuntimeError: If the data is not a
:class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
"""
if isinstance(data_list[0], LabelTensor):
return LabelTensor.cat(data_list)
if isinstance(data_list[0], torch.Tensor):
return torch.cat(data_list)
if isinstance(data_list[0], Data):
return self.dataset.create_batch(data_list)
raise RuntimeError(
"Data must be Tensors or LabelTensor or pyG "
"torch_geometric.data.Data"
)
def __call__(self, batch):
"""
        Perform the collation of data fetched from the dataset. The behavior
        of the function is set based on the batching strategy during class
        initialization.
:param batch: List of retrieved data or sampled indices.
:type batch: list[int] | list[dict]
        :return: Dictionary containing the collected data fetched from the dataset.
:rtype: dict
"""
return self.callable_function(batch)
class PinaSampler:
@@ -235,6 +40,19 @@ class PinaSampler:
        return sampler
class DataloaderCollector:
    def __init__(self, dataloader_list):
        """
        Initialize the object.
        """
        assert isinstance(dataloader_list, list)
        assert all(
            isinstance(dataloader, DataLoader) for dataloader in dataloader_list
        )
        self.dataloader_list = dataloader_list
class PinaDataModule(LightningDataModule):
    """
    This class extends :class:`~lightning.pytorch.core.LightningDataModule`,
@@ -376,23 +194,23 @@ class PinaDataModule(LightningDataModule):
if stage == "fit" or stage is None: if stage == "fit" or stage is None:
self.train_dataset = PinaDatasetFactory( self.train_dataset = PinaDatasetFactory(
self.data_splits["train"], self.data_splits["train"],
max_conditions_lengths=self.find_max_conditions_lengths( # max_conditions_lengths=self.find_max_conditions_lengths(
"train" # "train"
), # ),
automatic_batching=self.automatic_batching, automatic_batching=self.automatic_batching,
) )
if "val" in self.data_splits.keys(): if "val" in self.data_splits.keys():
self.val_dataset = PinaDatasetFactory( self.val_dataset = PinaDatasetFactory(
self.data_splits["val"], self.data_splits["val"],
max_conditions_lengths=self.find_max_conditions_lengths( # max_conditions_lengths=self.find_max_conditions_lengths(
"val" # "val"
), # ),
automatic_batching=self.automatic_batching, automatic_batching=self.automatic_batching,
) )
elif stage == "test": elif stage == "test":
self.test_dataset = PinaDatasetFactory( self.test_dataset = PinaDatasetFactory(
self.data_splits["test"], self.data_splits["test"],
max_conditions_lengths=self.find_max_conditions_lengths("test"), # max_conditions_lengths=self.find_max_conditions_lengths("test"),
automatic_batching=self.automatic_batching, automatic_batching=self.automatic_batching,
) )
else: else:
@@ -502,32 +320,14 @@ class PinaDataModule(LightningDataModule):
                ),
                module="lightning.pytorch.trainer.connectors.data_connector",
            )
-        # Use custom batching (good if batch size is large)
-        if self.batch_size is not None:
-            sampler = PinaSampler(dataset)
-            if self.automatic_batching:
-                collate = Collator(
-                    self.find_max_conditions_lengths(split),
-                    self.automatic_batching,
-                    dataset=dataset,
-                )
-            else:
-                collate = Collator(
-                    None, self.automatic_batching, dataset=dataset
-                )
-            return DataLoader(
-                dataset,
-                self.batch_size,
-                collate_fn=collate,
-                sampler=sampler,
-                num_workers=self.num_workers,
-            )
-        dataloader = DummyDataloader(dataset)
-        dataloader.dataset = self._transfer_batch_to_device(
-            dataloader.dataset, self.trainer.strategy.root_device, 0
-        )
-        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
-        return dataloader
+        return PinaDataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            num_workers=self.num_workers,
+            collate_fn=None,
+            common_batch_size=True,
+        )
    def find_max_conditions_lengths(self, split):
        """

pina/data/dataloader.py (new file, 245 lines)

@@ -0,0 +1,245 @@
from torch.utils.data import DataLoader
from functools import partial
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
import torch
class DummyDataloader:
def __init__(self, dataset):
"""
Prepare a dataloader object that returns the entire dataset in a single
batch. Depending on the number of GPUs, the dataset is managed
as follows:
- **Distributed Environment** (multiple GPUs): Divides dataset across
processes using the rank and world size. Fetches only portion of
data corresponding to the current process.
- **Non-Distributed Environment** (single GPU): Fetches the entire
dataset.
:param PinaDataset dataset: The dataset object to be processed.
.. note::
This dataloader is used when the batch size is ``None``.
"""
print("Using DummyDataloader")
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if len(dataset) < world_size:
raise RuntimeError(
"Dimension of the dataset smaller than world size."
" Increase the size of the partition or use a single GPU"
)
idx, i = [], rank
while i < len(dataset):
idx.append(i)
i += world_size
else:
idx = list(range(len(dataset)))
self.dataset = dataset._getitem_from_list(idx)
def __iter__(self):
return self
def __len__(self):
return 1
def __next__(self):
return self.dataset
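As an aside (not part of the file), here is a minimal, self-contained sketch of the rank-strided index selection used above: each process keeps the indices rank, rank + world_size, rank + 2*world_size, ..., so the dataset is partitioned across processes without overlap. The helper name and the sizes are invented for illustration.

def strided_indices(rank, world_size, dataset_len):
    # Same striding as in DummyDataloader.__init__, factored out for clarity.
    idx, i = [], rank
    while i < dataset_len:
        idx.append(i)
        i += world_size
    return idx

for rank in range(3):  # world_size = 3, dataset of 10 items (made-up numbers)
    print(rank, strided_indices(rank, 3, 10))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7]
# 2 [2, 5, 8]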
class PinaSampler:
"""
This class is used to create the sampler instance based on the shuffle
parameter and the environment in which the code is running.
"""
def __new__(cls, dataset, shuffle=True):
"""
Instantiate and initialize the sampler.
:param PinaDataset dataset: The dataset from which to sample.
:return: The sampler instance.
:rtype: :class:`torch.utils.data.Sampler`
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
sampler = DistributedSampler(dataset, shuffle=shuffle)
else:
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
return sampler
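A quick, hedged check of the fallback branch (not part of the file): outside an initialized torch.distributed process group, PinaSampler defined above degrades to a plain RandomSampler or SequentialSampler. The toy TensorDataset is only for illustration.

import torch
from torch.utils.data import TensorDataset

dummy = TensorDataset(torch.arange(6))  # toy dataset, non-distributed run assumed
print(type(PinaSampler(dummy, shuffle=True)).__name__)   # RandomSampler
print(type(PinaSampler(dummy, shuffle=False)).__name__)  # SequentialSampler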
def _collect_items(batch):
"""
    Helper function to collect items from a batch of data samples into
    per-field lists.
    :param batch: List of data samples.
    :return: Dictionary mapping each field name to the list of its values.
"""
to_return = {name: [] for name in batch[0].keys()}
for sample in batch:
for k, v in sample.items():
to_return[k].append(v)
return to_return
def collate_fn_custom(batch, dataset):
"""
Override the default collate function to handle datasets without automatic batching.
:param batch: List of indices from the dataset.
:param dataset: The PinaDataset instance (must be provided).
"""
return dataset._getitem_from_list(batch)
def collate_fn_default(batch, stack_fn):
"""
    Collate function used when automatic batching is enabled: collect the
    fields of each sample and stack them with the per-field stacking function.
    :param batch: List of data samples.
    :param stack_fn: Dictionary mapping each field name to its stacking function.
"""
print("Using default collate function")
to_return = _collect_items(batch)
return {k: stack_fn[k](v) for k, v in to_return.items()}
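For illustration only, a small usage sketch of the two helpers above: collate_fn_default transposes a list of per-sample dictionaries into per-field lists via _collect_items and then applies the per-field stacking function. The field names and tensors are invented.

import torch

batch = [
    {"input": torch.tensor([0.0, 1.0]), "target": torch.tensor([1.0])},
    {"input": torch.tensor([2.0, 3.0]), "target": torch.tensor([0.0])},
]
stack_fn = {"input": torch.stack, "target": torch.stack}

collated = collate_fn_default(batch, stack_fn)
print(collated["input"].shape)   # torch.Size([2, 2])
print(collated["target"].shape)  # torch.Size([2, 1])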
class PinaDataLoader:
"""
    Custom dataloader that wraps one :class:`torch.utils.data.DataLoader` per
    dataset in ``dataset_dict`` and yields a dictionary of batches at each step.
"""
def __init__(
self,
dataset_dict,
batch_size,
shuffle=False,
num_workers=0,
collate_fn=None,
common_batch_size=True,
    ):
        """
        Initialize the dataloader by building one dataloader per dataset in
        ``dataset_dict``. If ``common_batch_size`` is ``False``, the global
        batch size is split across the datasets proportionally to their lengths.
        """
self.dataset_dict = dataset_dict
self.batch_size = batch_size
self.shuffle = shuffle
self.num_workers = num_workers
self.collate_fn = collate_fn
print(batch_size)
if batch_size is None:
batch_size_per_dataset = {
split: None for split in dataset_dict.keys()
}
else:
if common_batch_size:
batch_size_per_dataset = {
split: batch_size for split in dataset_dict.keys()
}
else:
batch_size_per_dataset = self._compute_batch_size()
self.dataloaders = {
split: self._create_dataloader(
dataset, batch_size_per_dataset[split]
)
for split, dataset in dataset_dict.items()
}
def _compute_batch_size(self):
"""
        Split the global batch size across the datasets proportionally to the
        number of elements in each one.
"""
elements_per_dataset = {
dataset_name: len(dataset)
for dataset_name, dataset in self.dataset_dict.items()
}
total_elements = sum(el for el in elements_per_dataset.values())
portion_per_dataset = {
name: el / total_elements
for name, el in elements_per_dataset.items()
}
batch_size_per_dataset = {
name: max(1, int(portion * self.batch_size))
for name, portion in portion_per_dataset.items()
}
tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
if self.batch_size > tot_el_per_batch:
difference = self.batch_size - tot_el_per_batch
while difference > 0:
for k, v in batch_size_per_dataset.items():
if difference == 0:
break
if v > 1:
batch_size_per_dataset[k] += 1
difference -= 1
if self.batch_size < tot_el_per_batch:
difference = tot_el_per_batch - self.batch_size
while difference > 0:
for k, v in batch_size_per_dataset.items():
if difference == 0:
break
if v > 1:
batch_size_per_dataset[k] -= 1
difference -= 1
return batch_size_per_dataset
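A worked example of the proportional allocation idea implemented above (a sketch with invented numbers, not the exact method): with a global batch size of 16 and splits of 1000/250/250 samples, the integer shares are 10/2/2 and the two leftover slots are then redistributed.

sizes = {"physics": 1000, "data": 250, "boundary": 250}  # hypothetical splits
batch_size = 16

total = sum(sizes.values())
alloc = {k: max(1, int(batch_size * n / total)) for k, n in sizes.items()}
# -> {'physics': 10, 'data': 2, 'boundary': 2}: two slots still unassigned
leftover = batch_size - sum(alloc.values())
for k in sorted(alloc, key=sizes.get, reverse=True):
    if leftover == 0:
        break
    alloc[k] += 1
    leftover -= 1
print(alloc)  # {'physics': 11, 'data': 3, 'boundary': 2} -> sums to 16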
    def _create_dataloader(self, dataset, batch_size):
        """
        Build the dataloader for a single dataset: a ``DummyDataloader`` when
        ``batch_size`` is ``None``, otherwise a :class:`torch.utils.data.DataLoader`
        with the appropriate collate function and sampler.
        """
        print(batch_size)
if batch_size is None:
return DummyDataloader(dataset)
if not dataset.automatic_batching:
collate_fn = partial(collate_fn_custom, dataset=dataset)
else:
collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn)
return DataLoader(
dataset,
batch_size=batch_size,
num_workers=self.num_workers,
collate_fn=collate_fn,
sampler=PinaSampler(dataset, shuffle=self.shuffle),
)
def __len__(self):
return max(len(dl) for dl in self.dataloaders.values())
def __iter__(self):
"""
        Return an iterator that yields a dictionary of batches per step.
        It iterates for as many steps as the longest dataloader (as reported
        by ``__len__``) and restarts the shorter dataloaders once exhausted.
        """
        # 1. Create an iterator for each dataloader
        iterators = {split: iter(dl) for split, dl in self.dataloaders.items()}
        # 2. Iterate for the number of batches of the longest dataloader
        for _ in range(len(self)):
            # 3. Prepare the batch dictionary for this step
            batch_dict = {}
            # 4. Get the next batch from each iterator
            for split, it in iterators.items():
                try:
                    batch = next(it)
                except StopIteration:
                    # 5. If an iterator is exhausted, reset it and take its first batch
                    new_it = iter(self.dataloaders[split])
                    iterators[split] = new_it  # Store the new iterator
                    batch = next(new_it)
                batch_dict[split] = batch
            # 6. Yield the batch dictionary
            yield batch_dict
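A hedged end-to-end sketch of the iteration behaviour (the import paths, condition names and tensor shapes are assumptions, not part of the commit): two splits of different lengths are wrapped in a single PinaDataLoader, and the shorter one is cycled while the longer one finishes its epoch.

import torch
from pina.data.dataset import PinaTensorDataset   # assumed import path
from pina.data.dataloader import PinaDataLoader   # assumed import path

datasets = {
    "physics": PinaTensorDataset(
        {"input": torch.rand(10, 2), "target": torch.rand(10, 1)},
        automatic_batching=False,
    ),
    "data": PinaTensorDataset(
        {"input": torch.rand(4, 2), "target": torch.rand(4, 1)},
        automatic_batching=False,
    ),
}
loader = PinaDataLoader(datasets, batch_size=4, shuffle=False)

print(len(loader))  # 3: the longest per-split dataloader (10 samples / 4)
for batch in loader:
    # One dictionary per step; the shorter "data" split is cycled.
    print(batch["physics"]["input"].shape, batch["data"]["input"].shape)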

pina/data/dataset.py

@@ -1,9 +1,10 @@
"""Module for the PINA dataset classes.""" """Module for the PINA dataset classes."""
from abc import abstractmethod, ABC
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torch_geometric.data import Data from torch_geometric.data import Data
from ..graph import Graph, LabelBatch from ..graph import Graph, LabelBatch
from ..label_tensor import LabelTensor
import torch
class PinaDatasetFactory:
@@ -41,286 +42,156 @@ class PinaDatasetFactory:
        if len(conditions_dict) == 0:
            raise ValueError("No conditions provided")
+        dataset_dict = {}
        # Check if a Graph is present in the conditions
-        is_graph = cls._is_graph_dataset(conditions_dict)
-        if is_graph:
-            # If a Graph is present, return a PinaGraphDataset
-            return PinaGraphDataset(conditions_dict, **kwargs)
-        # If no Graph is present, return a PinaTensorDataset
-        return PinaTensorDataset(conditions_dict, **kwargs)
+        for name, data in conditions_dict.items():
+            if not isinstance(data, dict):
+                raise ValueError(
+                    f"Condition '{name}' data must be a dictionary"
+                )
+            # is_graph = cls._is_graph_dataset(conditions_dict)
+            # if is_graph:
+            #     raise NotImplementedError("PinaGraphDataset is not implemented yet.")
+            dataset_dict[name] = PinaTensorDataset(data, **kwargs)
+        return dataset_dict

    @staticmethod
-    def _is_graph_dataset(conditions_dict):
+    def _is_graph_dataset(cond_data):
        """
-        Check if a graph is present in the conditions (at least one time).
-        :param conditions_dict: Dictionary containing the conditions.
-        :type conditions_dict: dict
-        :return: True if a graph is present in the conditions, False otherwise.
-        :rtype: bool
+        TODO: Docstring
        """
-        # Iterate over the conditions dictionary
-        for v in conditions_dict.values():
-            # Iterate over the values of the current condition
-            for cond in v.values():
-                # Check if the current value is a list of Data objects
-                if isinstance(cond, (Data, Graph, list, tuple)):
-                    return True
+        # Iterate over the values of the current condition
+        for cond in cond_data.values():
+            if isinstance(cond, (Data, Graph, list, tuple)):
+                return True
        return False
-class PinaDataset(Dataset, ABC):
-    """
-    Abstract class for the PINA dataset which extends the PyTorch
-    :class:`~torch.utils.data.Dataset` class. It defines the common interface
-    for :class:`~pina.data.dataset.PinaTensorDataset` and
-    :class:`~pina.data.dataset.PinaGraphDataset` classes.
-    """
-
-    def __init__(
-        self, conditions_dict, max_conditions_lengths, automatic_batching
-    ):
-        """
-        Initialize the instance by storing the conditions dictionary, the
-        maximum number of items per condition to consider, and the automatic
-        batching flag.
-
-        :param dict conditions_dict: A dictionary mapping condition names to
-            their respective data. Each key represents a condition name, and the
-            corresponding value is a dictionary containing the associated data.
-        :param dict max_conditions_lengths: Maximum number of data points that
-            can be included in a single batch per condition.
-        :param bool automatic_batching: Indicates whether PyTorch automatic
-            batching is enabled in
-            :class:`~pina.data.data_module.PinaDataModule`.
-        """
-        # Store the conditions dictionary
-        self.conditions_dict = conditions_dict
-        # Store the maximum number of conditions to consider
-        self.max_conditions_lengths = max_conditions_lengths
-        # Store length of each condition
-        self.conditions_length = {
-            k: len(v["input"]) for k, v in self.conditions_dict.items()
-        }
-        # Store the maximum length of the dataset
-        self.length = max(self.conditions_length.values())
-        # Dynamically set the getitem function based on automatic batching
-        if automatic_batching:
-            self._getitem_func = self._getitem_int
-        else:
-            self._getitem_func = self._getitem_dummy
-
-    def _get_max_len(self):
-        """
-        Returns the length of the longest condition in the dataset.
-
-        :return: Length of the longest condition in the dataset.
-        :rtype: int
-        """
-        max_len = 0
-        for condition in self.conditions_dict.values():
-            max_len = max(max_len, len(condition["input"]))
-        return max_len
-
-    def __len__(self):
-        return self.length
-
-    def __getitem__(self, idx):
-        return self._getitem_func(idx)
-
-    def _getitem_dummy(self, idx):
-        """
-        Return the index itself. This is used when automatic batching is
-        disabled to postpone the data retrieval to the dataloader.
-
-        :param int idx: Index.
-        :return: Index.
-        :rtype: int
-        """
-        # If automatic batching is disabled, return the data at the given index
-        return idx
-
-    def _getitem_int(self, idx):
-        """
-        Return the data at the given index in the dataset. This is used when
-        automatic batching is enabled.
-
-        :param int idx: Index.
-        :return: A dictionary containing the data at the given index.
-        :rtype: dict
-        """
-        # If automatic batching is enabled, return the data at the given index
-        return {
-            k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()}
-            for k, v in self.conditions_dict.items()
-        }
-
-    def get_all_data(self):
-        """
-        Return all data in the dataset.
-
-        :return: A dictionary containing all the data in the dataset.
-        :rtype: dict
-        """
-        to_return_dict = {}
-        for condition, data in self.conditions_dict.items():
-            len_condition = len(
-                data["input"]
-            )  # Length of the current condition
-            to_return_dict[condition] = self._retrive_data(
-                data, list(range(len_condition))
-            )  # Retrieve the data from the current condition
-        return to_return_dict
-
-    def fetch_from_idx_list(self, idx):
-        """
-        Return data from the dataset given a list of indices.
-
-        :param list[int] idx: List of indices.
-        :return: A dictionary containing the data at the given indices.
-        :rtype: dict
-        """
-        to_return_dict = {}
-        for condition, data in self.conditions_dict.items():
-            # Get the indices for the current condition
-            cond_idx = idx[: self.max_conditions_lengths[condition]]
-            # Get the length of the current condition
-            condition_len = self.conditions_length[condition]
-            # If the length of the dataset is greater than the length of the
-            # current condition, repeat the indices
-            if self.length > condition_len:
-                cond_idx = [idx % condition_len for idx in cond_idx]
-            # Retrieve the data from the current condition
-            to_return_dict[condition] = self._retrive_data(data, cond_idx)
-        return to_return_dict
-
-    @abstractmethod
-    def _retrive_data(self, data, idx_list):
-        """
-        Abstract method to retrieve data from the dataset given a list of
-        indices.
-        """
-
-
-class PinaTensorDataset(PinaDataset):
-    """
-    Dataset class for the PINA dataset with :class:`torch.Tensor` and
-    :class:`~pina.label_tensor.LabelTensor` data.
-    """
-
-    # Override _retrive_data method for torch.Tensor data
-    def _retrive_data(self, data, idx_list):
-        """
-        Retrieve data from the dataset given a list of indices.
-
-        :param dict data: Dictionary containing the data
-            (only :class:`torch.Tensor` or
-            :class:`~pina.label_tensor.LabelTensor`).
-        :param list[int] idx_list: indices to retrieve.
-        :return: Dictionary containing the data at the given indices.
-        :rtype: dict
-        """
-        return {k: v[idx_list] for k, v in data.items()}
-
-    @property
-    def input(self):
-        """
-        Return the input data for the dataset.
-
-        :return: Dictionary containing the input points.
-        :rtype: dict
-        """
-        return {k: v["input"] for k, v in self.conditions_dict.items()}
-
-    def update_data(self, new_conditions_dict):
-        """
-        Update the dataset with new data.
-
-        This method is used to update the dataset with new data. It replaces
-        the current data with the new data provided in the new_conditions_dict
-        parameter.
-
-        :param dict new_conditions_dict: Dictionary containing the new data.
-        :return: None
-        """
-        for condition, data in new_conditions_dict.items():
-            if condition in self.conditions_dict:
-                self.conditions_dict[condition].update(data)
-            else:
-                self.conditions_dict[condition] = data
+class PinaTensorDataset(Dataset):
+    """
+    Dataset class for the PINA dataset with :class:`torch.Tensor` and
+    :class:`~pina.label_tensor.LabelTensor` data.
+    """
+
+    def __init__(self, data_dict, automatic_batching=None):
+        """
+        Initialize the instance by storing the conditions dictionary.
+
+        :param dict conditions_dict: A dictionary mapping condition names to
+            their respective data. Each key represents a condition name, and the
+            corresponding value is a dictionary containing the associated data.
+        """
+        # Store the conditions dictionary
+        self.data = data_dict
+        self.automatic_batching = (
+            automatic_batching if automatic_batching is not None else True
+        )
+        self.stack_fn = (
+            {}
+        )  # LabelTensor.stack if any(isinstance(v, LabelTensor) for v in data_dict.values()) else torch.stack
+        for k, v in data_dict.items():
+            if isinstance(v, LabelTensor):
+                self.stack_fn[k] = LabelTensor.stack
+            elif isinstance(v, torch.Tensor):
+                self.stack_fn[k] = torch.stack
+            elif isinstance(v, list) and all(
+                isinstance(item, (Data, Graph)) for item in v
+            ):
+                self.stack_fn[k] = LabelBatch.from_data_list
+            else:
+                raise ValueError(
+                    f"Unsupported data type for stacking: {type(v)}"
+                )
+
+    def __len__(self):
+        return len(next(iter(self.data.values())))
+
+    def __getitem__(self, idx):
+        """
+        Return the data at the given index in the dataset.
+
+        :param int idx: Index.
+        :return: A dictionary containing the data at the given index.
+        :rtype: dict
+        """
+        if self.automatic_batching:
+            # Return the data at the given index
+            return {
+                field_name: data[idx] for field_name, data in self.data.items()
+            }
+        return idx
+
+    def _getitem_from_list(self, idx_list):
+        """
+        Return data from the dataset given a list of indices.
+
+        :param list[int] idx_list: List of indices.
+        :return: A dictionary containing the data at the given indices.
+        :rtype: dict
+        """
+        to_return = {}
+        for field_name, data in self.data.items():
+            if self.stack_fn[field_name] == LabelBatch.from_data_list:
+                to_return[field_name] = self.stack_fn[field_name](
+                    [data[i] for i in idx_list]
+                )
+            else:
+                to_return[field_name] = data[idx_list]
+        return to_return
-class PinaGraphDataset(PinaDataset):
-    """
-    Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data`
-    and :class:`~pina.graph.Graph` data.
-    """
-
-    def _create_graph_batch(self, data):
-        """
-        Create a LabelBatch object from a list of
-        :class:`~torch_geometric.data.Data` objects.
-
-        :param data: List of items to collate in a single batch.
-        :type data: list[Data] | list[Graph]
-        :return: LabelBatch object with all the graphs collated in a single
-            batch of disconnected graphs.
-        :rtype: LabelBatch
-        """
-        batch = LabelBatch.from_data_list(data)
-        return batch
-
-    def create_batch(self, data):
-        """
-        Create a Batch object from a list of :class:`~torch_geometric.data.Data`
-        objects.
-
-        :param data: List of items to collate in a single batch.
-        :type data: list[Data] | list[Graph]
-        :return: Batch object.
-        :rtype: :class:`~torch_geometric.data.Batch`
-            | :class:`~pina.graph.LabelBatch`
-        """
-        if isinstance(data[0], Data):
-            return self._create_graph_batch(data)
-        return self._create_tensor_batch(data)
-
-    # Override _retrive_data method for graph handling
-    def _retrive_data(self, data, idx_list):
-        """
-        Retrieve data from the dataset given a list of indices.
-
-        :param dict data: Dictionary containing the data.
-        :param list[int] idx_list: List of indices to retrieve.
-        :return: Dictionary containing the data at the given indices.
-        :rtype: dict
-        """
-        # Return the data from the current condition
-        # If the data is a list of Data objects, create a Batch object
-        # If the data is a list of torch.Tensor objects, create a torch.Tensor
-        return {
-            k: (
-                self._create_graph_batch([v[i] for i in idx_list])
-                if isinstance(v, list)
-                else v[idx_list]
-            )
-            for k, v in data.items()
-        }
-
-    @property
-    def input(self):
-        """
-        Return the input data for the dataset.
-
-        :return: Dictionary containing the input points.
-        :rtype: dict
-        """
-        return {k: v["input"] for k, v in self.conditions_dict.items()}
+class PinaGraphDataset(Dataset):
+    def __init__(self, data_dict, automatic_batching=None):
+        """
+        Initialize the instance by storing the conditions dictionary.
+
+        :param dict conditions_dict: A dictionary mapping condition names to
+            their respective data. Each key represents a condition name, and the
+            corresponding value is a dictionary containing the associated data.
+        """
+        # Store the conditions dictionary
+        self.data = data_dict
+        self.automatic_batching = (
+            automatic_batching if automatic_batching is not None else True
+        )
+
+    def __len__(self):
+        return len(next(iter(self.data.values())))
+
+    def __getitem__(self, idx):
+        """
+        Return the data at the given index in the dataset.
+
+        :param int idx: Index.
+        :return: A dictionary containing the data at the given index.
+        :rtype: dict
+        """
+        if self.automatic_batching:
+            # Return the data at the given index
+            return {
+                field_name: data[idx] for field_name, data in self.data.items()
+            }
+        return idx
+
+    def _getitem_from_list(self, idx_list):
+        """
+        Return data from the dataset given a list of indices.
+
+        :param list[int] idx_list: List of indices.
+        :return: A dictionary containing the data at the given indices.
+        :rtype: dict
+        """
+        return {
+            field_name: [data[i] for i in idx_list]
+            for field_name, data in self.data.items()
+        }