This commit is contained in:
FilippoOlivo
2025-11-14 17:03:38 +01:00
parent 43163fdf74
commit 0b877d86b9
6 changed files with 20 additions and 20 deletions

View File

@@ -26,6 +26,7 @@ Trainer, Dataset and Datamodule
Trainer <trainer.rst>
Dataset <data/dataset.rst>
DataModule <data/data_module.rst>
Dataloader <data/dataloader.rst>
Data Types
------------

View File

@@ -2,14 +2,6 @@ DataModule
======================
.. currentmodule:: pina.data.data_module
.. autoclass:: Collator
:members:
:show-inheritance:
.. autoclass:: PinaDataModule
:members:
:show-inheritance:
.. autoclass:: PinaSampler
:members:
:show-inheritance:

View File

@@ -0,0 +1,11 @@
Dataloader
======================
.. currentmodule:: pina.data.dataloader
.. autoclass:: PinaSampler
:members:
:show-inheritance:
.. autoclass:: PinaDataLoader
:members:
:show-inheritance:

View File

@@ -7,13 +7,5 @@ Dataset
:show-inheritance:
.. autoclass:: PinaDatasetFactory
:members:
:show-inheritance:
.. autoclass:: PinaGraphDataset
:members:
:show-inheritance:
.. autoclass:: PinaTensorDataset
:members:
:show-inheritance:

View File

@@ -163,6 +163,7 @@ class PinaDataLoader:
):
"""
Initialize the PinaDataLoader.
:param dict dataset_dict: A dictionary mapping dataset names to their
respective PinaDataset instances.
:param int batch_size: The batch size for the dataloader.
@@ -172,6 +173,7 @@ class PinaDataLoader:
"common_batch_size", "separate_conditions", and "proportional".
:param device: The device to which the data should be moved.
"""
self.dataset_dict = dataset_dict
self.batch_size = batch_size
self.num_workers = num_workers
@@ -209,6 +211,7 @@ class PinaDataLoader:
"""
Compute an appropriate batch size for the given dataset.
"""
# Compute number of elements per dataset
elements_per_dataset = {
dataset_name: len(dataset)
@@ -281,6 +284,7 @@ class PinaDataLoader:
def __len__(self):
"""
Return the length of the dataloader.
:return: The length of the dataloader.
:rtype: int
"""
@@ -293,6 +297,7 @@ class PinaDataLoader:
def __iter__(self):
"""
Iterate over the dataloader.
:return: Yields batches from the dataloader.
:rtype: dict
"""

View File

@@ -17,16 +17,13 @@ class PinaDatasetFactory:
"""
Factory class to create PINA datasets based on the provided conditions
dictionary.
:param dict conditions_dict: A dictionary where keys are condition names
and values are dictionaries containing the associated data.
:return: A dictionary mapping condition names to their respective
:class:`PinaDataset` instances.
"""
def __new__(cls, conditions_dict, **kwargs):
"""
Create PINA dataset instances based on the provided conditions
dictionary.
:param dict conditions_dict: A dictionary where keys are condition names
and values are dictionaries containing the associated data.
:return: A dictionary mapping condition names to their respective
@@ -92,6 +89,7 @@ class PinaDataset(Dataset):
def __len__(self):
"""
Return the length of the dataset.
:return: The length of the dataset.
:rtype: int
"""
@@ -134,6 +132,7 @@ class PinaDataset(Dataset):
def update_data(self, update_dict):
"""
Update the dataset's data in-place.
:param dict update_dict: A dictionary where keys are condition names
and values are dictionaries with updated data for those conditions.
"""