Compare commits

10 Commits

Author SHA1 Message Date
FilippoOlivo
2521e4d2bd add stack conditions option 2025-11-22 12:59:22 +01:00
FilippoOlivo
0b877d86b9 fix doc 2025-11-14 17:03:38 +01:00
FilippoOlivo
43163fdf74 fix tests and modules 2025-11-14 16:52:10 +01:00
FilippoOlivo
8440a672a7 fix tests 2025-11-13 17:03:31 +01:00
FilippoOlivo
0ee63686dd fix 2025-11-13 17:03:18 +01:00
FilippoOlivo
51a0399111 fix doc 2025-11-13 14:09:44 +01:00
FilippoOlivo
18b02f43c5 fix some codacy warnings 2025-11-13 14:01:18 +01:00
FilippoOlivo
6bb44052b0 add update_data and input functions 2025-11-13 10:48:47 +01:00
FilippoOlivo
c0cbb13a92 fix callbacks 2025-11-13 10:48:20 +01:00
FilippoOlivo
09677d3c15 integrate new datamodule in trainer 2025-11-12 15:59:48 +01:00
17 changed files with 646 additions and 478 deletions

View File

@@ -26,6 +26,7 @@ Trainer, Dataset and Datamodule
Trainer <trainer.rst> Trainer <trainer.rst>
Dataset <data/dataset.rst> Dataset <data/dataset.rst>
DataModule <data/data_module.rst> DataModule <data/data_module.rst>
Dataloader <data/dataloader.rst>
Data Types Data Types
------------ ------------

View File

@@ -2,14 +2,6 @@ DataModule
====================== ======================
.. currentmodule:: pina.data.data_module .. currentmodule:: pina.data.data_module
.. autoclass:: Collator
:members:
:show-inheritance:
.. autoclass:: PinaDataModule .. autoclass:: PinaDataModule
:members:
:show-inheritance:
.. autoclass:: PinaSampler
:members: :members:
:show-inheritance: :show-inheritance:

View File

@@ -0,0 +1,11 @@
Dataloader
======================
.. currentmodule:: pina.data.dataloader
.. autoclass:: PinaSampler
:members:
:show-inheritance:
.. autoclass:: PinaDataLoader
:members:
:show-inheritance:

View File

@@ -7,13 +7,5 @@ Dataset
:show-inheritance: :show-inheritance:
.. autoclass:: PinaDatasetFactory .. autoclass:: PinaDatasetFactory
:members:
:show-inheritance:
.. autoclass:: PinaGraphDataset
:members:
:show-inheritance:
.. autoclass:: PinaTensorDataset
:members: :members:
:show-inheritance: :show-inheritance:

View File

@@ -5,7 +5,6 @@ from lightning.pytorch import Callback
from ..label_tensor import LabelTensor from ..label_tensor import LabelTensor
from ..utils import check_consistency, is_function from ..utils import check_consistency, is_function
from ..condition import InputTargetCondition from ..condition import InputTargetCondition
from ..data.dataset import PinaGraphDataset
class NormalizerDataCallback(Callback): class NormalizerDataCallback(Callback):
@@ -122,7 +121,10 @@ class NormalizerDataCallback(Callback):
""" """
# Ensure datasets are not graph-based # Ensure datasets are not graph-based
if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset): if any(
ds.is_graph_dataset
for ds in trainer.datamodule.train_dataset.values()
):
raise NotImplementedError( raise NotImplementedError(
"NormalizerDataCallback is not compatible with " "NormalizerDataCallback is not compatible with "
"graph-based datasets." "graph-based datasets."
@@ -164,8 +166,8 @@ class NormalizerDataCallback(Callback):
:param dataset: The `~pina.data.dataset.PinaDataset` dataset. :param dataset: The `~pina.data.dataset.PinaDataset` dataset.
""" """
for cond in conditions: for cond in conditions:
if cond in dataset.conditions_dict: if cond in dataset:
data = dataset.conditions_dict[cond][self.apply_to] data = dataset[cond].data[self.apply_to]
shift = self.shift_fn(data) shift = self.shift_fn(data)
scale = self.scale_fn(data) scale = self.scale_fn(data)
self._normalizer[cond] = { self._normalizer[cond] = {
@@ -197,25 +199,20 @@ class NormalizerDataCallback(Callback):
:param PinaDataset dataset: The dataset to be normalized. :param PinaDataset dataset: The dataset to be normalized.
""" """
# Initialize update dictionary
update_dataset_dict = {}
# Iterate over conditions and apply normalization # Iterate over conditions and apply normalization
for cond, norm_params in self.normalizer.items(): for cond, norm_params in self.normalizer.items():
points = dataset.conditions_dict[cond][self.apply_to] update_dataset_dict = {}
points = dataset[cond].data[self.apply_to]
scale = norm_params["scale"] scale = norm_params["scale"]
shift = norm_params["shift"] shift = norm_params["shift"]
normalized_points = self._norm_fn(points, scale, shift) normalized_points = self._norm_fn(points, scale, shift)
update_dataset_dict[cond] = { update_dataset_dict[self.apply_to] = (
self.apply_to: ( LabelTensor(normalized_points, points.labels)
LabelTensor(normalized_points, points.labels) if isinstance(points, LabelTensor)
if isinstance(points, LabelTensor) else normalized_points
else normalized_points )
) dataset[cond].data.update(update_dataset_dict)
}
# Update the dataset in-place
dataset.update_data(update_dataset_dict)
@property @property
def normalizer(self): def normalizer(self):
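
A minimal sketch of the per-condition access pattern used above, assuming dataset[cond].data behaves like a plain dict and shift_fn / scale_fn are simple torch reductions (the helper name is illustrative, not part of the codebase):

    import torch

    def normalize_condition(dataset, cond, apply_to="input",
                            shift_fn=torch.mean, scale_fn=torch.std):
        # Read one condition's data, compute shift/scale, write back in place.
        points = dataset[cond].data[apply_to]
        shift, scale = shift_fn(points), scale_fn(points)
        dataset[cond].data.update({apply_to: (points - shift) / scale})
        return shift, scale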

View File

@@ -133,13 +133,12 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
:param PINNInterface solver: The solver object. :param PINNInterface solver: The solver object.
""" """
new_points = {}
for name in self._condition_to_update: for name in self._condition_to_update:
current_points = self.dataset.conditions_dict[name]["input"] new_points = {}
new_points[name] = { current_points = self.dataset[name].data["input"]
"input": self.sample(current_points, name, solver) new_points["input"] = self.sample(current_points, name, solver)
}
self.dataset.update_data(new_points) self.dataset[name].update_data(new_points)
def _compute_population_size(self, conditions): def _compute_population_size(self, conditions):
""" """
@@ -150,6 +149,5 @@ class RefinementInterface(Callback, metaclass=ABCMeta):
:rtype: dict :rtype: dict
""" """
return { return {
cond: len(self.dataset.conditions_dict[cond]["input"]) cond: len(self.dataset[cond].data["input"]) for cond in conditions
for cond in conditions
} }
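
For clarity, a small sketch of the same per-condition update flow; "gamma" and sample_fn are hypothetical placeholders, not names taken from the codebase:

    current_points = dataset["gamma"].data["input"]
    new_points = {"input": sample_fn(current_points)}  # must return a (Label)Tensor
    dataset["gamma"].update_data(new_points)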

View File

@@ -27,8 +27,7 @@ class PinaDataModule(LightningDataModule):
val_size=0.1, val_size=0.1,
batch_size=None, batch_size=None,
shuffle=True, shuffle=True,
common_batch_size=True, batching_mode="common_batch_size",
separate_conditions=False,
automatic_batching=None, automatic_batching=None,
num_workers=0, num_workers=0,
pin_memory=False, pin_memory=False,
@@ -84,8 +83,7 @@ class PinaDataModule(LightningDataModule):
# Store fixed attributes # Store fixed attributes
self.batch_size = batch_size self.batch_size = batch_size
self.shuffle = shuffle self.shuffle = shuffle
self.common_batch_size = common_batch_size self.batching_mode = batching_mode
self.separate_conditions = separate_conditions
self.automatic_batching = automatic_batching self.automatic_batching = automatic_batching
# If batch size is None, num_workers has no effect # If batch size is None, num_workers has no effect
@@ -255,7 +253,7 @@ class PinaDataModule(LightningDataModule):
dataset_dict[key].update({condition_name: data}) dataset_dict[key].update({condition_name: data})
return dataset_dict return dataset_dict
def _create_dataloader(self, split, dataset): def _create_dataloader(self, dataset):
""" " """ "
Create the dataloader for the given split. Create the dataloader for the given split.
@@ -275,15 +273,18 @@ class PinaDataModule(LightningDataModule):
), ),
module="lightning.pytorch.trainer.connectors.data_connector", module="lightning.pytorch.trainer.connectors.data_connector",
) )
return PinaDataLoader( dl = PinaDataLoader(
dataset, dataset,
batch_size=self.batch_size, batch_size=self.batch_size,
shuffle=self.shuffle, shuffle=self.shuffle,
num_workers=self.num_workers, num_workers=self.num_workers,
collate_fn=None, batching_mode=self.batching_mode,
common_batch_size=self.common_batch_size, device=self.trainer.strategy.root_device,
separate_conditions=self.separate_conditions,
) )
if self.batch_size is None:
# Override the method to transfer the batch to the device
self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
return dl
def val_dataloader(self): def val_dataloader(self):
""" """
@@ -292,7 +293,7 @@ class PinaDataModule(LightningDataModule):
:return: The validation dataloader :return: The validation dataloader
:rtype: torch.utils.data.DataLoader :rtype: torch.utils.data.DataLoader
""" """
return self._create_dataloader("val", self.val_dataset) return self._create_dataloader(self.val_dataset)
def train_dataloader(self): def train_dataloader(self):
""" """
@@ -301,7 +302,7 @@ class PinaDataModule(LightningDataModule):
:return: The training dataloader :return: The training dataloader
:rtype: torch.utils.data.DataLoader :rtype: torch.utils.data.DataLoader
""" """
return self._create_dataloader("train", self.train_dataset) return self._create_dataloader(self.train_dataset)
def test_dataloader(self): def test_dataloader(self):
""" """
@@ -310,7 +311,7 @@ class PinaDataModule(LightningDataModule):
:return: The testing dataloader :return: The testing dataloader
:rtype: torch.utils.data.DataLoader :rtype: torch.utils.data.DataLoader
""" """
return self._create_dataloader("test", self.test_dataset) return self._create_dataloader(self.test_dataset)
@staticmethod @staticmethod
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
@@ -326,7 +327,7 @@ class PinaDataModule(LightningDataModule):
:rtype: list[tuple] :rtype: list[tuple]
""" """
return batch return list(batch.items())
def _transfer_batch_to_device(self, batch, device, dataloader_idx): def _transfer_batch_to_device(self, batch, device, dataloader_idx):
""" """
@@ -384,9 +385,15 @@ class PinaDataModule(LightningDataModule):
to_return = {} to_return = {}
if hasattr(self, "train_dataset") and self.train_dataset is not None: if hasattr(self, "train_dataset") and self.train_dataset is not None:
to_return["train"] = self.train_dataset.input to_return["train"] = {
cond: data.input for cond, data in self.train_dataset.items()
}
if hasattr(self, "val_dataset") and self.val_dataset is not None: if hasattr(self, "val_dataset") and self.val_dataset is not None:
to_return["val"] = self.val_dataset.input to_return["val"] = {
cond: data.input for cond, data in self.val_dataset.items()
}
if hasattr(self, "test_dataset") and self.test_dataset is not None: if hasattr(self, "test_dataset") and self.test_dataset is not None:
to_return["test"] = self.test_dataset.input to_return["test"] = {
cond: data.input for cond, data in self.test_dataset.items()
}
return to_return return to_return
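
A hedged usage sketch of the updated PinaDataModule constructor, mirroring the keyword arguments visible in this diff and in the tests further down; problem stands in for any PINA problem instance:

    dm = PinaDataModule(
        problem,
        train_size=0.7,
        val_size=0.1,
        test_size=0.2,
        batch_size=32,
        shuffle=True,
        batching_mode="common_batch_size",  # or "separate_conditions" / "proportional"
        automatic_batching=False,
        num_workers=0,
        pin_memory=False,
    )
    dm.setup()
    # train_dataset is now a dict mapping condition names to PinaDataset objects
    for cond, ds in dm.train_dataset.items():
        print(cond, len(ds))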

View File

@@ -1,13 +1,21 @@
from torch.utils.data import DataLoader """DataLoader module for PinaDataset."""
import itertools
import random
from functools import partial from functools import partial
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler from torch.utils.data.sampler import SequentialSampler
import torch from .stacked_dataloader import StackedDataLoader
class DummyDataloader: class DummyDataloader:
"""
DataLoader that returns the entire dataset in a single batch.
"""
def __init__(self, dataset): def __init__(self, dataset, device=None):
""" """
Prepare a dataloader object that returns the entire dataset in a single Prepare a dataloader object that returns the entire dataset in a single
batch. Depending on the number of GPUs, the dataset is managed batch. Depending on the number of GPUs, the dataset is managed
@@ -24,34 +32,52 @@ class DummyDataloader:
.. note:: .. note::
This dataloader is used when the batch size is ``None``. This dataloader is used when the batch size is ``None``.
""" """
print("Using DummyDataloader") # Handle distributed environment
if ( if PinaSampler.is_distributed():
torch.distributed.is_available() # Get rank and world size
and torch.distributed.is_initialized()
):
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
# Ensure dataset is large enough
if len(dataset) < world_size: if len(dataset) < world_size:
raise RuntimeError( raise RuntimeError(
"Dimension of the dataset smaller than world size." "Dimension of the dataset smaller than world size."
" Increase the size of the partition or use a single GPU" " Increase the size of the partition or use a single GPU"
) )
# Split dataset among processes
idx, i = [], rank idx, i = [], rank
while i < len(dataset): while i < len(dataset):
idx.append(i) idx.append(i)
i += world_size i += world_size
else: else:
idx = list(range(len(dataset))) idx = list(range(len(dataset)))
self.dataset = dataset.getitem_from_list(idx)
self.dataset = dataset._getitem_from_list(idx) self.device = device
self.dataset = (
{k: v.to(self.device) for k, v in self.dataset.items()}
if self.device
else self.dataset
)
def __iter__(self): def __iter__(self):
"""
Iterate over the dataloader.
"""
return self return self
def __len__(self): def __len__(self):
"""
Return the length of the dataloader, which is always 1.
:return: The length of the dataloader.
:rtype: int
"""
return 1 return 1
def __next__(self): def __next__(self):
"""
Return the entire dataset as a single batch.
:return: The entire dataset.
:rtype: dict
"""
return self.dataset return self.dataset
@@ -70,10 +96,7 @@ class PinaSampler:
:rtype: :class:`torch.utils.data.Sampler` :rtype: :class:`torch.utils.data.Sampler`
""" """
if ( if cls.is_distributed():
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
sampler = DistributedSampler(dataset, shuffle=shuffle) sampler = DistributedSampler(dataset, shuffle=shuffle)
else: else:
if shuffle: if shuffle:
@@ -82,6 +105,18 @@ class PinaSampler:
sampler = SequentialSampler(dataset) sampler = SequentialSampler(dataset)
return sampler return sampler
@staticmethod
def is_distributed():
"""
Check if the sampler is distributed.
:return: True if the sampler is distributed, False otherwise.
:rtype: bool
"""
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
)
def _collect_items(batch): def _collect_items(batch):
""" """
@@ -97,11 +132,12 @@ def _collect_items(batch):
def collate_fn_custom(batch, dataset): def collate_fn_custom(batch, dataset):
""" """
Override the default collate function to handle datasets without automatic batching. Override the default collate function to handle datasets without automatic
batching.
:param batch: List of indices from the dataset. :param batch: List of indices from the dataset.
:param dataset: The PinaDataset instance (must be provided). :param dataset: The PinaDataset instance (must be provided).
""" """
return dataset._getitem_from_list(batch) return dataset.getitem_from_list(batch)
def collate_fn_default(batch, stack_fn): def collate_fn_default(batch, stack_fn):
@@ -109,7 +145,6 @@ def collate_fn_default(batch, stack_fn):
Default collate function that simply returns the batch as is. Default collate function that simply returns the batch as is.
:param batch: List of data samples. :param batch: List of data samples.
""" """
print("Using default collate function")
to_return = _collect_items(batch) to_return = _collect_items(batch)
return {k: stack_fn[k](v) for k, v in to_return.items()} return {k: stack_fn[k](v) for k, v in to_return.items()}
@@ -119,34 +154,71 @@ class PinaDataLoader:
Custom DataLoader for PinaDataset. Custom DataLoader for PinaDataset.
""" """
def __new__(cls, *args, **kwargs):
batching_mode = kwargs.get("batching_mode", "common_batch_size").lower()
batch_size = kwargs.get("batch_size")
if batching_mode == "stacked" and batch_size is not None:
return StackedDataLoader(
args[0],
batch_size=batch_size,
shuffle=kwargs.get("shuffle", True),
)
elif batch_size is None:
kwargs["batching_mode"] = "proportional"
print(
"Using PinaDataLoader with batching mode:", kwargs["batching_mode"]
)
return super(PinaDataLoader, cls).__new__(cls)
def __init__( def __init__(
self, self,
dataset_dict, dataset_dict,
batch_size, batch_size,
shuffle=False,
num_workers=0, num_workers=0,
collate_fn=None, shuffle=False,
common_batch_size=True, batching_mode="common_batch_size",
separate_conditions=False, device=None,
): ):
"""
Initialize the PinaDataLoader.
:param dict dataset_dict: A dictionary mapping dataset names to their
respective PinaDataset instances.
:param int batch_size: The batch size for the dataloader.
:param int num_workers: Number of worker processes for data loading.
:param bool shuffle: Whether to shuffle the data at every epoch.
:param str batching_mode: The batching mode to use. Options are
"common_batch_size", "separate_conditions", and "proportional".
:param device: The device to which the data should be moved.
"""
self.dataset_dict = dataset_dict self.dataset_dict = dataset_dict
self.batch_size = batch_size self.batch_size = batch_size
self.shuffle = shuffle
self.num_workers = num_workers self.num_workers = num_workers
self.collate_fn = collate_fn self.shuffle = shuffle
self.separate_conditions = separate_conditions self.batching_mode = batching_mode.lower()
self.device = device
# Batch size None means we want to load the entire dataset in a single
# batch
if batch_size is None: if batch_size is None:
batch_size_per_dataset = { batch_size_per_dataset = {
split: None for split in dataset_dict.keys() split: None for split in dataset_dict.keys()
} }
else: else:
if common_batch_size: # Compute batch size per dataset
if batching_mode in ["common_batch_size", "separate_conditions"]:
# (the sum of the batch sizes is equal to
# n_conditions * batch_size)
batch_size_per_dataset = { batch_size_per_dataset = {
split: batch_size for split in dataset_dict.keys() split: min(batch_size, len(ds))
for split, ds in dataset_dict.items()
} }
else: elif batching_mode == "proportional":
# (the sum of the batch sizes is equal to the specified batch size)
batch_size_per_dataset = self._compute_batch_size() batch_size_per_dataset = self._compute_batch_size()
# Create a dataloader per dataset
self.dataloaders = { self.dataloaders = {
split: self._create_dataloader( split: self._create_dataloader(
dataset, batch_size_per_dataset[split] dataset, batch_size_per_dataset[split]
@@ -158,21 +230,27 @@ class PinaDataLoader:
""" """
Compute an appropriate batch size for the given dataset. Compute an appropriate batch size for the given dataset.
""" """
# Compute number of elements per dataset
elements_per_dataset = { elements_per_dataset = {
dataset_name: len(dataset) dataset_name: len(dataset)
for dataset_name, dataset in self.dataset_dict.items() for dataset_name, dataset in self.dataset_dict.items()
} }
# Compute the total number of elements
total_elements = sum(el for el in elements_per_dataset.values()) total_elements = sum(el for el in elements_per_dataset.values())
# Compute the portion of each dataset
portion_per_dataset = { portion_per_dataset = {
name: el / total_elements name: el / total_elements
for name, el in elements_per_dataset.items() for name, el in elements_per_dataset.items()
} }
# Compute batch size per dataset. Ensure at least 1 element per
# dataset.
batch_size_per_dataset = { batch_size_per_dataset = {
name: max(1, int(portion * self.batch_size)) name: max(1, int(portion * self.batch_size))
for name, portion in portion_per_dataset.items() for name, portion in portion_per_dataset.items()
} }
# Adjust batch sizes to match the specified total batch size
tot_el_per_batch = sum(el for el in batch_size_per_dataset.values()) tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
if self.batch_size > tot_el_per_batch: if self.batch_size > tot_el_per_batch:
difference = self.batch_size - tot_el_per_batch difference = self.batch_size - tot_el_per_batch
while difference > 0: while difference > 0:
@@ -194,49 +272,76 @@ class PinaDataLoader:
return batch_size_per_dataset return batch_size_per_dataset
def _create_dataloader(self, dataset, batch_size): def _create_dataloader(self, dataset, batch_size):
print(batch_size) """
if batch_size is None: Create the dataloader for the given dataset.
return DummyDataloader(dataset)
:param PinaDataset dataset: The dataset for which to create the
dataloader.
:param int batch_size: The batch size for the dataloader.
:return: The created dataloader.
:rtype: :class:`torch.utils.data.DataLoader`
"""
# If batch size is None, use DummyDataloader
if batch_size is None or batch_size >= len(dataset):
return DummyDataloader(dataset, device=self.device)
# Determine the appropriate collate function
if not dataset.automatic_batching: if not dataset.automatic_batching:
collate_fn = partial(collate_fn_custom, dataset=dataset) collate_fn = partial(collate_fn_custom, dataset=dataset)
else: else:
collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn) collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn)
# Create and return the dataloader
return DataLoader( return DataLoader(
dataset, dataset,
batch_size=batch_size, batch_size=batch_size,
num_workers=self.num_workers,
collate_fn=collate_fn, collate_fn=collate_fn,
num_workers=self.num_workers,
sampler=PinaSampler(dataset, shuffle=self.shuffle), sampler=PinaSampler(dataset, shuffle=self.shuffle),
) )
def __len__(self): def __len__(self):
if self.separate_conditions: """
Return the length of the dataloader.
:return: The length of the dataloader.
:rtype: int
"""
# If separate conditions, return sum of lengths of all dataloaders
# else, return max length among dataloaders
if self.batching_mode == "separate_conditions":
return sum(len(dl) for dl in self.dataloaders.values()) return sum(len(dl) for dl in self.dataloaders.values())
return max(len(dl) for dl in self.dataloaders.values()) return max(len(dl) for dl in self.dataloaders.values())
def __iter__(self): def __iter__(self):
""" """
Return an iterator that yields dictionaries of batches. Iterate over the dataloader. Yields a dictionary mapping split name to batch.
Iterate for as many steps as the longest dataloader (as per __len__) The iteration logic for 'separate_conditions' is now iterative and memory-efficient.
and restart the shorter dataloaders when they run out.
""" """
if self.separate_conditions: if self.batching_mode == "separate_conditions":
tmp = []
for split, dl in self.dataloaders.items(): for split, dl in self.dataloaders.items():
for batch in dl: len_split = len(dl)
yield {split: batch} for i, batch in enumerate(dl):
tmp.append({split: batch})
if i + 1 >= len_split:
break
random.shuffle(tmp)
for batch_dict in tmp:
yield batch_dict
return return
iterators = {split: iter(dl) for split, dl in self.dataloaders.items()} # Common_batch_size or Proportional mode (round-robin sampling)
iterators = {
split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
}
# Iterate for the length of the longest dataloader
for _ in range(len(self)): for _ in range(len(self)):
batch_dict = {} batch_dict: BatchDict = {}
for split, it in iterators.items(): for split, it in iterators.items():
try: # Since we use itertools.cycle, next(it) will always yield a batch
batch = next(it) # by repeating the dataset, so no need for the 'if batch is None: return' check.
except StopIteration: batch_dict[split] = next(it)
new_it = iter(self.dataloaders[split])
iterators[split] = new_it
batch = next(new_it)
batch_dict[split] = batch
yield batch_dict yield batch_dict
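
A sketch of driving the new PinaDataLoader directly, under the assumption that dataset_a and dataset_b are PinaDataset instances for two conditions:

    loader = PinaDataLoader(
        {"data_1": dataset_a, "data_2": dataset_b},
        batch_size=16,
        num_workers=0,
        shuffle=True,
        batching_mode="proportional",
    )
    for batch in loader:
        # In the round-robin modes each batch is a dict of per-condition batches,
        # e.g. {"data_1": {"input": ..., "target": ...}, "data_2": {...}}
        pass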

View File

@@ -6,29 +6,44 @@ from torch_geometric.data import Data
from ..graph import Graph, LabelBatch from ..graph import Graph, LabelBatch
from ..label_tensor import LabelTensor from ..label_tensor import LabelTensor
STACK_FN_MAP = {
"label_tensor": LabelTensor.stack,
"tensor": torch.stack,
"data": LabelBatch.from_data_list,
}
class PinaDatasetFactory: class PinaDatasetFactory:
""" """
TODO: Update docstring Factory class to create PINA datasets based on the provided conditions
dictionary.
""" """
def __new__(cls, conditions_dict, **kwargs): def __new__(cls, conditions_dict, **kwargs):
""" """
TODO: Update docstring Create PINA dataset instances based on the provided conditions
dictionary.
:param dict conditions_dict: A dictionary where keys are condition names
and values are dictionaries containing the associated data.
:return: A dictionary mapping condition names to their respective
:class:`PinaDataset` instances.
""" """
# Check if conditions_dict is empty # Check if conditions_dict is empty
if len(conditions_dict) == 0: if len(conditions_dict) == 0:
raise ValueError("No conditions provided") raise ValueError("No conditions provided")
dataset_dict = {} dataset_dict = {} # Dictionary to hold the created datasets
# Check if a Graph is present in the conditions # Check if a Graph is present in the conditions
for name, data in conditions_dict.items(): for name, data in conditions_dict.items():
# Validate that data is a dictionary
if not isinstance(data, dict): if not isinstance(data, dict):
raise ValueError( raise ValueError(
f"Condition '{name}' data must be a dictionary" f"Condition '{name}' data must be a dictionary"
) )
# Create PinaDataset instance for each condition
dataset_dict[name] = PinaDataset(data, **kwargs) dataset_dict[name] = PinaDataset(data, **kwargs)
return dataset_dict return dataset_dict
@@ -53,23 +68,31 @@ class PinaDataset(Dataset):
self.automatic_batching = ( self.automatic_batching = (
automatic_batching if automatic_batching is not None else True automatic_batching if automatic_batching is not None else True
) )
self.stack_fn = {} self._stack_fn = {}
self.is_graph_dataset = False
# Determine stacking functions for each data type (used in collate_fn) # Determine stacking functions for each data type (used in collate_fn)
for k, v in data_dict.items(): for k, v in data_dict.items():
if isinstance(v, LabelTensor): if isinstance(v, LabelTensor):
self.stack_fn[k] = LabelTensor.stack self._stack_fn[k] = "label_tensor"
elif isinstance(v, torch.Tensor): elif isinstance(v, torch.Tensor):
self.stack_fn[k] = torch.stack self._stack_fn[k] = "tensor"
elif isinstance(v, list) and all( elif isinstance(v, list) and all(
isinstance(item, (Data, Graph)) for item in v isinstance(item, (Data, Graph)) for item in v
): ):
self.stack_fn[k] = LabelBatch.from_data_list self._stack_fn[k] = "data"
self.is_graph_dataset = True
else: else:
raise ValueError( raise ValueError(
f"Unsupported data type for stacking: {type(v)}" f"Unsupported data type for stacking: {type(v)}"
) )
def __len__(self): def __len__(self):
"""
Return the length of the dataset.
:return: The length of the dataset.
:rtype: int
"""
return len(next(iter(self.data.values()))) return len(next(iter(self.data.values())))
def __getitem__(self, idx): def __getitem__(self, idx):
@@ -88,7 +111,7 @@ class PinaDataset(Dataset):
} }
return idx return idx
def _getitem_from_list(self, idx_list): def getitem_from_list(self, idx_list):
""" """
Return data from the dataset given a list of indices. Return data from the dataset given a list of indices.
@@ -99,60 +122,49 @@ class PinaDataset(Dataset):
to_return = {} to_return = {}
for field_name, data in self.data.items(): for field_name, data in self.data.items():
if self.stack_fn[field_name] == LabelBatch.from_data_list: if self._stack_fn[field_name] == "data":
to_return[field_name] = self.stack_fn[field_name]( fn = STACK_FN_MAP[self._stack_fn[field_name]]
[data[i] for i in idx_list] to_return[field_name] = fn([data[i] for i in idx_list])
)
else: else:
to_return[field_name] = data[idx_list] to_return[field_name] = data[idx_list]
return to_return return to_return
def update_data(self, update_dict):
class PinaGraphDataset(Dataset):
def __init__(self, data_dict, automatic_batching=None):
""" """
Initialize the instance by storing the conditions dictionary. Update the dataset's data in-place.
:param dict conditions_dict: A dictionary mapping condition names to :param dict update_dict: A dictionary where keys are condition names
their respective data. Each key represents a condition name, and the and values are dictionaries with updated data for those conditions.
corresponding value is a dictionary containing the associated data.
""" """
for field_name, updates in update_dict.items():
if field_name not in self.data:
raise KeyError(
f"Condition '{field_name}' not found in dataset."
)
if not isinstance(updates, (LabelTensor, torch.Tensor)):
raise ValueError(
f"Updates for condition '{field_name}' must be of type "
f"LabelTensor or torch.Tensor."
)
self.data[field_name] = updates
# Store the conditions dictionary @property
self.data = data_dict def input(self):
self.automatic_batching = (
automatic_batching if automatic_batching is not None else True
)
def __len__(self):
return len(next(iter(self.data.values())))
def __getitem__(self, idx):
""" """
Return the data at the given index in the dataset. Get the input data from the dataset.
:param int idx: Index. :return: The input data.
:return: A dictionary containing the data at the given index. :rtype: torch.Tensor | LabelTensor | Data | Graph
"""
return self.data["input"]
@property
def stack_fn(self):
"""
Get the mapping of stacking functions for each data type in the dataset.
:return: A dictionary mapping condition names to their respective
stacking function identifiers.
:rtype: dict :rtype: dict
""" """
return {k: STACK_FN_MAP[v] for k, v in self._stack_fn.items()}
if self.automatic_batching:
# Return the data at the given index
return {
field_name: data[idx] for field_name, data in self.data.items()
}
return idx
def _getitem_from_list(self, idx_list):
"""
Return data from the dataset given a list of indices.
:param list[int] idx_list: List of indices.
:return: A dictionary containing the data at the given indices.
:rtype: dict
"""
return {
field_name: [data[i] for i in idx_list]
for field_name, data in self.data.items()
}
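
To show the shape of the factory's new return value, a hedged sketch assuming input_tensor and output_tensor are torch tensors of equal length:

    datasets = PinaDatasetFactory(
        {"data": {"input": input_tensor, "target": output_tensor}},
        automatic_batching=True,
    )
    ds = datasets["data"]        # one PinaDataset per condition
    n_samples = len(ds)
    sample = ds[0]               # {"input": ..., "target": ...} with automatic batching
    ds.update_data({"target": output_tensor * 2.0})  # in-place field update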

View File

@@ -0,0 +1,53 @@
import torch
from math import ceil
class StackedDataLoader:
def __init__(self, datasets, batch_size=32, shuffle=True):
for d in datasets.values():
if d.is_graph_dataset:
raise ValueError("Each dataset must be a dictionary")
self.chunks = {}
self.total_length = 0
self.indices = []
self._init_chunks(datasets)
self.indices = list(range(self.total_length))
self.batch_size = batch_size
self.shuffle = shuffle
if self.shuffle:
torch.random.manual_seed(42)
self.indices = torch.randperm(self.total_length).tolist()
self.datasets = datasets
def _init_chunks(self, datasets):
inc = 0
total_length = 0
for name, dataset in datasets.items():
self.chunks[name] = {"start": inc, "end": inc + len(dataset)}
inc += len(dataset)
self.total_length = inc
def __len__(self):
return ceil(self.total_length / self.batch_size)
def _build_batch_indices(self, batch_idx):
start = batch_idx * self.batch_size
end = min(start + self.batch_size, self.total_length)
return self.indices[start:end]
def __iter__(self):
for batch_idx in range(len(self)):
batch_indices = self._build_batch_indices(batch_idx)
batch_data = {}
for name, chunk in self.chunks.items():
local_indices = [
idx - chunk["start"]
for idx in batch_indices
if chunk["start"] <= idx < chunk["end"]
]
if local_indices:
batch_data[name] = self.datasets[name].getitem_from_list(
local_indices
)
yield batch_data
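
A sketch of the stacked mode, which concatenates all (non-graph) conditions into a single index space; dataset_a and dataset_b are placeholder PinaDataset instances:

    loader = StackedDataLoader(
        {"data_1": dataset_a, "data_2": dataset_b},
        batch_size=8,
        shuffle=True,
    )
    for batch in loader:
        # Each batch contains only the conditions whose global indices fall in the slice,
        # e.g. {"data_1": {...}} or {"data_1": {...}, "data_2": {...}}
        pass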

View File

@@ -31,7 +31,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size=0.0, test_size=0.0,
val_size=0.0, val_size=0.0,
compile=None, compile=None,
repeat=None, batching_mode="common_batch_size",
automatic_batching=None, automatic_batching=None,
num_workers=None, num_workers=None,
pin_memory=None, pin_memory=None,
@@ -56,10 +56,12 @@ class Trainer(lightning.pytorch.Trainer):
:param bool compile: If ``True``, the model is compiled before training. :param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled. Not Default is ``False``. For Windows users, it is always disabled. Not
supported for python version greater or equal than 3.14. supported for python version greater or equal than 3.14.
:param bool repeat: Whether to repeat the dataset data in each :param bool common_batch_size: If ``True``, the same batch size is used
condition during training. For further details, see the for all conditions. If ``False``, each condition can have its own
:class:`~pina.data.data_module.PinaDataModule` class. Default is batch size, proportional to the size of the dataset in that
``False``. condition. Default is ``True``.
:param bool separate_conditions: If ``True``, dataloaders for each
condition are iterated separately. Default is ``False``.
:param bool automatic_batching: If ``True``, automatic PyTorch batching :param bool automatic_batching: If ``True``, automatic PyTorch batching
is performed, otherwise the items are retrieved from the dataset is performed, otherwise the items are retrieved from the dataset
all at once. For further details, see the all at once. For further details, see the
@@ -82,7 +84,7 @@ class Trainer(lightning.pytorch.Trainer):
train_size=train_size, train_size=train_size,
test_size=test_size, test_size=test_size,
val_size=val_size, val_size=val_size,
repeat=repeat, batching_mode=batching_mode,
automatic_batching=automatic_batching, automatic_batching=automatic_batching,
compile=compile, compile=compile,
) )
@@ -122,8 +124,6 @@ class Trainer(lightning.pytorch.Trainer):
UserWarning, UserWarning,
) )
repeat = repeat if repeat is not None else False
automatic_batching = ( automatic_batching = (
automatic_batching if automatic_batching is not None else False automatic_batching if automatic_batching is not None else False
) )
@@ -139,7 +139,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size=test_size, test_size=test_size,
val_size=val_size, val_size=val_size,
batch_size=batch_size, batch_size=batch_size,
repeat=repeat, batching_mode=batching_mode,
automatic_batching=automatic_batching, automatic_batching=automatic_batching,
pin_memory=pin_memory, pin_memory=pin_memory,
num_workers=num_workers, num_workers=num_workers,
@@ -177,7 +177,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size, test_size,
val_size, val_size,
batch_size, batch_size,
repeat, batching_mode,
automatic_batching, automatic_batching,
pin_memory, pin_memory,
num_workers, num_workers,
@@ -196,8 +196,10 @@ class Trainer(lightning.pytorch.Trainer):
:param float val_size: The percentage of elements to include in the :param float val_size: The percentage of elements to include in the
validation dataset. validation dataset.
:param int batch_size: The number of samples per batch to load. :param int batch_size: The number of samples per batch to load.
:param bool repeat: Whether to repeat the dataset data in each :param bool common_batch_size: Whether to use the same batch size for
condition during training. all conditions.
:param bool separate_conditions: Whether to iterate dataloaders for
each condition separately.
:param bool automatic_batching: Whether to perform automatic batching :param bool automatic_batching: Whether to perform automatic batching
with PyTorch. with PyTorch.
:param bool pin_memory: Whether to use pinned memory for faster data :param bool pin_memory: Whether to use pinned memory for faster data
@@ -227,7 +229,7 @@ class Trainer(lightning.pytorch.Trainer):
test_size=test_size, test_size=test_size,
val_size=val_size, val_size=val_size,
batch_size=batch_size, batch_size=batch_size,
repeat=repeat, batching_mode=batching_mode,
automatic_batching=automatic_batching, automatic_batching=automatic_batching,
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, pin_memory=pin_memory,
@@ -279,7 +281,7 @@ class Trainer(lightning.pytorch.Trainer):
train_size, train_size,
test_size, test_size,
val_size, val_size,
repeat, batching_mode,
automatic_batching, automatic_batching,
compile, compile,
): ):
@@ -293,8 +295,10 @@ class Trainer(lightning.pytorch.Trainer):
test dataset. test dataset.
:param float val_size: The percentage of elements to include in the :param float val_size: The percentage of elements to include in the
validation dataset. validation dataset.
:param bool repeat: Whether to repeat the dataset data in each :param bool common_batch_size: Whether to use the same batch size for
condition during training. all conditions.
:param bool separate_conditions: Whether to iterate dataloaders for
each condition separately.
:param bool automatic_batching: Whether to perform automatic batching :param bool automatic_batching: Whether to perform automatic batching
with PyTorch. with PyTorch.
:param bool compile: If ``True``, the model is compiled before training. :param bool compile: If ``True``, the model is compiled before training.
@@ -304,8 +308,7 @@ class Trainer(lightning.pytorch.Trainer):
check_consistency(train_size, float) check_consistency(train_size, float)
check_consistency(test_size, float) check_consistency(test_size, float)
check_consistency(val_size, float) check_consistency(val_size, float)
if repeat is not None: check_consistency(batching_mode, str)
check_consistency(repeat, bool)
if automatic_batching is not None: if automatic_batching is not None:
check_consistency(automatic_batching, bool) check_consistency(automatic_batching, bool)
if compile is not None: if compile is not None:
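
A hedged end-to-end sketch of the Trainer with the new batching_mode argument, mirroring the keyword arguments exercised in the tests below; solver is a placeholder solver instance:

    trainer = Trainer(
        solver,
        batch_size=32,
        train_size=0.7,
        val_size=0.2,
        test_size=0.1,
        batching_mode="separate_conditions",
        automatic_batching=False,
    )
    trainer.train()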

View File

@@ -51,7 +51,7 @@ def test_sample(condition_to_update):
} }
trainer.train() trainer.train()
after_n_points = { after_n_points = {
loc: len(trainer.data_module.train_dataset.input[loc]) loc: len(trainer.data_module.train_dataset[loc].input)
for loc in condition_to_update for loc in condition_to_update
} }
assert before_n_points == trainer.callbacks[0].initial_population_size assert before_n_points == trainer.callbacks[0].initial_population_size

View File

@@ -142,14 +142,10 @@ def test_setup(solver, fn, stage, apply_to):
for cond in ["data1", "data2"]: for cond in ["data1", "data2"]:
scale = scale_fn( scale = scale_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][ trainer_copy.data_module.train_dataset[cond].data[apply_to]
apply_to
]
) )
shift = shift_fn( shift = shift_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][ trainer_copy.data_module.train_dataset[cond].data[apply_to]
apply_to
]
) )
assert "scale" in normalizer[cond] assert "scale" in normalizer[cond]
assert "shift" in normalizer[cond] assert "shift" in normalizer[cond]
@@ -158,8 +154,8 @@ def test_setup(solver, fn, stage, apply_to):
for ds_name in stage_map[stage]: for ds_name in stage_map[stage]:
dataset = getattr(trainer.data_module, ds_name, None) dataset = getattr(trainer.data_module, ds_name, None)
old_dataset = getattr(trainer_copy.data_module, ds_name, None) old_dataset = getattr(trainer_copy.data_module, ds_name, None)
current_points = dataset.conditions_dict[cond][apply_to] current_points = dataset[cond].data[apply_to]
old_points = old_dataset.conditions_dict[cond][apply_to] old_points = old_dataset[cond].data[apply_to]
expected = (old_points - shift) / scale expected = (old_points - shift) / scale
assert torch.allclose(current_points, expected) assert torch.allclose(current_points, expected)
@@ -204,10 +200,10 @@ def test_setup_pinn(fn, stage, apply_to):
cond = "data" cond = "data"
scale = scale_fn( scale = scale_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to] trainer_copy.data_module.train_dataset[cond].data[apply_to]
) )
shift = shift_fn( shift = shift_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to] trainer_copy.data_module.train_dataset[cond].data[apply_to]
) )
assert "scale" in normalizer[cond] assert "scale" in normalizer[cond]
assert "shift" in normalizer[cond] assert "shift" in normalizer[cond]
@@ -216,8 +212,8 @@ def test_setup_pinn(fn, stage, apply_to):
for ds_name in stage_map[stage]: for ds_name in stage_map[stage]:
dataset = getattr(trainer.data_module, ds_name, None) dataset = getattr(trainer.data_module, ds_name, None)
old_dataset = getattr(trainer_copy.data_module, ds_name, None) old_dataset = getattr(trainer_copy.data_module, ds_name, None)
current_points = dataset.conditions_dict[cond][apply_to] current_points = dataset[cond].data[apply_to]
old_points = old_dataset.conditions_dict[cond][apply_to] old_points = old_dataset[cond].data[apply_to]
expected = (old_points - shift) / scale expected = (old_points - shift) / scale
assert torch.allclose(current_points, expected) assert torch.allclose(current_points, expected)
@@ -242,3 +238,7 @@ def test_setup_graph_dataset():
) )
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
trainer.train() trainer.train()
# if __name__ == "__main__":
# test_setup(supervised_solver_lt, [torch.std, torch.mean], "all", "input")

View File

@@ -1,10 +1,11 @@
import torch import torch
import pytest import pytest
from pina.data import PinaDataModule from pina.data import PinaDataModule
from pina.data.dataset import PinaTensorDataset, PinaGraphDataset from pina.data.dataset import PinaDataset
from pina.problem.zoo import SupervisedProblem from pina.problem.zoo import SupervisedProblem
from pina.graph import RadiusGraph from pina.graph import RadiusGraph
from pina.data.data_module import DummyDataloader
from pina.data.dataloader import DummyDataloader, PinaDataLoader
from pina import Trainer from pina import Trainer
from pina.solver import SupervisedSolver from pina.solver import SupervisedSolver
from torch_geometric.data import Batch from torch_geometric.data import Batch
@@ -44,22 +45,33 @@ def test_setup_train(input_, output_, train_size, val_size, test_size):
) )
dm.setup() dm.setup()
assert hasattr(dm, "train_dataset") assert hasattr(dm, "train_dataset")
if isinstance(input_, torch.Tensor): assert isinstance(dm.train_dataset, dict)
assert isinstance(dm.train_dataset, PinaTensorDataset) assert all(
else: isinstance(dm.train_dataset[cond], PinaDataset)
assert isinstance(dm.train_dataset, PinaGraphDataset) for cond in dm.train_dataset
# assert len(dm.train_dataset) == int(len(input_) * train_size) )
assert all(
dm.train_dataset[cond].is_graph_dataset == isinstance(input_, list)
for cond in dm.train_dataset
)
assert all(
len(dm.train_dataset[cond]) == int(len(input_) * train_size)
for cond in dm.train_dataset
)
if test_size > 0: if test_size > 0:
assert hasattr(dm, "test_dataset") assert hasattr(dm, "test_dataset")
assert dm.test_dataset is None assert dm.test_dataset is None
else: else:
assert not hasattr(dm, "test_dataset") assert not hasattr(dm, "test_dataset")
assert hasattr(dm, "val_dataset") assert hasattr(dm, "val_dataset")
if isinstance(input_, torch.Tensor):
assert isinstance(dm.val_dataset, PinaTensorDataset) assert isinstance(dm.val_dataset, dict)
else: assert all(
assert isinstance(dm.val_dataset, PinaGraphDataset) isinstance(dm.val_dataset[cond], PinaDataset) for cond in dm.val_dataset
# assert len(dm.val_dataset) == int(len(input_) * val_size) )
assert all(
isinstance(dm.val_dataset[cond], PinaDataset) for cond in dm.val_dataset
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -87,49 +99,59 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
assert not hasattr(dm, "val_dataset") assert not hasattr(dm, "val_dataset")
assert hasattr(dm, "test_dataset") assert hasattr(dm, "test_dataset")
if isinstance(input_, torch.Tensor): assert all(
assert isinstance(dm.test_dataset, PinaTensorDataset) isinstance(dm.test_dataset[cond], PinaDataset)
else: for cond in dm.test_dataset
assert isinstance(dm.test_dataset, PinaGraphDataset) )
# assert len(dm.test_dataset) == int(len(input_) * test_size) assert all(
dm.test_dataset[cond].is_graph_dataset == isinstance(input_, list)
for cond in dm.test_dataset
@pytest.mark.parametrize( )
"input_, output_", assert all(
[(input_tensor, output_tensor), (input_graph, output_graph)], len(dm.test_dataset[cond]) == int(len(input_) * test_size)
) for cond in dm.test_dataset
def test_dummy_dataloader(input_, output_):
problem = SupervisedProblem(input_=input_, output_=output_)
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
trainer = Trainer(
solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0
) )
dm = trainer.data_module
dm.setup()
dm.trainer = trainer
dataloader = dm.train_dataloader()
assert isinstance(dataloader, DummyDataloader)
assert len(dataloader) == 1
data = next(dataloader)
assert isinstance(data, list)
assert isinstance(data[0], tuple)
if isinstance(input_, list):
assert isinstance(data[0][1]["input"], Batch)
else:
assert isinstance(data[0][1]["input"], torch.Tensor)
assert isinstance(data[0][1]["target"], torch.Tensor)
dataloader = dm.val_dataloader()
assert isinstance(dataloader, DummyDataloader) # @pytest.mark.parametrize(
assert len(dataloader) == 1 # "input_, output_",
data = next(dataloader) # [(input_tensor, output_tensor), (input_graph, output_graph)],
assert isinstance(data, list) # )
assert isinstance(data[0], tuple) # def test_dummy_dataloader(input_, output_):
if isinstance(input_, list): # problem = SupervisedProblem(input_=input_, output_=output_)
assert isinstance(data[0][1]["input"], Batch) # solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
else: # trainer = Trainer(
assert isinstance(data[0][1]["input"], torch.Tensor) # solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0
assert isinstance(data[0][1]["target"], torch.Tensor) # )
# dm = trainer.data_module
# dm.setup()
# dm.trainer = trainer
# dataloader = dm.train_dataloader()
# assert isinstance(dataloader, PinaDataLoader)
# print(dataloader.dataloaders)
# assert all([isinstance(ds, DummyDataloader) for ds in dataloader.dataloaders.values()])
# data = next(iter(dataloader))
# assert isinstance(data, list)
# assert isinstance(data[0], tuple)
# if isinstance(input_, list):
# assert isinstance(data[0][1]["input"], Batch)
# else:
# assert isinstance(data[0][1]["input"], torch.Tensor)
# assert isinstance(data[0][1]["target"], torch.Tensor)
# dataloader = dm.val_dataloader()
# assert isinstance(dataloader, DummyDataloader)
# assert len(dataloader) == 1
# data = next(dataloader)
# assert isinstance(data, list)
# assert isinstance(data[0], tuple)
# if isinstance(input_, list):
# assert isinstance(data[0][1]["input"], Batch)
# else:
# assert isinstance(data[0][1]["input"], torch.Tensor)
# assert isinstance(data[0][1]["target"], torch.Tensor)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -137,7 +159,11 @@ def test_dummy_dataloader(input_, output_):
[(input_tensor, output_tensor), (input_graph, output_graph)], [(input_tensor, output_tensor), (input_graph, output_graph)],
) )
@pytest.mark.parametrize("automatic_batching", [True, False]) @pytest.mark.parametrize("automatic_batching", [True, False])
def test_dataloader(input_, output_, automatic_batching): @pytest.mark.parametrize("batch_size", [None, 10])
@pytest.mark.parametrize("batching_mode", ["common_batch_size", "propotional"])
def test_dataloader(
input_, output_, automatic_batching, batch_size, batching_mode
):
problem = SupervisedProblem(input_=input_, output_=output_) problem = SupervisedProblem(input_=input_, output_=output_)
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
trainer = Trainer( trainer = Trainer(
@@ -147,12 +173,13 @@ def test_dataloader(input_, output_, automatic_batching):
val_size=0.3, val_size=0.3,
test_size=0.0, test_size=0.0,
automatic_batching=automatic_batching, automatic_batching=automatic_batching,
batching_mode=batching_mode,
) )
dm = trainer.data_module dm = trainer.data_module
dm.setup() dm.setup()
dm.trainer = trainer dm.trainer = trainer
dataloader = dm.train_dataloader() dataloader = dm.train_dataloader()
assert isinstance(dataloader, DataLoader) assert isinstance(dataloader, PinaDataLoader)
assert len(dataloader) == 7 assert len(dataloader) == 7
data = next(iter(dataloader)) data = next(iter(dataloader))
assert isinstance(data, dict) assert isinstance(data, dict)
@@ -163,8 +190,8 @@ def test_dataloader(input_, output_, automatic_batching):
assert isinstance(data["data"]["target"], torch.Tensor) assert isinstance(data["data"]["target"], torch.Tensor)
dataloader = dm.val_dataloader() dataloader = dm.val_dataloader()
assert isinstance(dataloader, DataLoader) assert isinstance(dataloader, PinaDataLoader)
assert len(dataloader) == 3 assert len(dataloader) == 3 if batch_size is not None else 1
data = next(iter(dataloader)) data = next(iter(dataloader))
assert isinstance(data, dict) assert isinstance(data, dict)
if isinstance(input_, list): if isinstance(input_, list):
@@ -202,12 +229,13 @@ def test_dataloader_labels(input_, output_, automatic_batching):
val_size=0.3, val_size=0.3,
test_size=0.0, test_size=0.0,
automatic_batching=automatic_batching, automatic_batching=automatic_batching,
# common_batch_size=True,
) )
dm = trainer.data_module dm = trainer.data_module
dm.setup() dm.setup()
dm.trainer = trainer dm.trainer = trainer
dataloader = dm.train_dataloader() dataloader = dm.train_dataloader()
assert isinstance(dataloader, DataLoader) assert isinstance(dataloader, PinaDataLoader)
assert len(dataloader) == 7 assert len(dataloader) == 7
data = next(iter(dataloader)) data = next(iter(dataloader))
assert isinstance(data, dict) assert isinstance(data, dict)
@@ -223,7 +251,7 @@ def test_dataloader_labels(input_, output_, automatic_batching):
assert data["data"]["target"].labels == ["u", "v", "w"] assert data["data"]["target"].labels == ["u", "v", "w"]
dataloader = dm.val_dataloader() dataloader = dm.val_dataloader()
assert isinstance(dataloader, DataLoader) assert isinstance(dataloader, PinaDataLoader)
assert len(dataloader) == 3 assert len(dataloader) == 3
data = next(iter(dataloader)) data = next(iter(dataloader))
assert isinstance(data, dict) assert isinstance(data, dict)
@@ -240,39 +268,6 @@ def test_dataloader_labels(input_, output_, automatic_batching):
assert data["data"]["target"].labels == ["u", "v", "w"] assert data["data"]["target"].labels == ["u", "v", "w"]
def test_get_all_data():
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
target = input
problem = SupervisedProblem(input, target)
datamodule = PinaDataModule(
problem,
train_size=0.7,
test_size=0.2,
val_size=0.1,
batch_size=64,
shuffle=False,
repeat=False,
automatic_batching=None,
num_workers=0,
pin_memory=False,
)
datamodule.setup("fit")
datamodule.setup("test")
assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700
assert torch.isclose(
datamodule.train_dataset.get_all_data()["data"]["input"], input[:700]
).all()
assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100
assert torch.isclose(
datamodule.val_dataset.get_all_data()["data"]["input"], input[900:]
).all()
assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200
assert torch.isclose(
datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900]
).all()
def test_input_propery_tensor(): def test_input_propery_tensor():
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)]) input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
target = input target = input
@@ -285,7 +280,6 @@ def test_input_propery_tensor():
val_size=0.1, val_size=0.1,
batch_size=64, batch_size=64,
shuffle=False, shuffle=False,
repeat=False,
automatic_batching=None, automatic_batching=None,
num_workers=0, num_workers=0,
pin_memory=False, pin_memory=False,
@@ -311,7 +305,6 @@ def test_input_propery_graph():
val_size=0.1, val_size=0.1,
batch_size=64, batch_size=64,
shuffle=False, shuffle=False,
repeat=False,
automatic_batching=None, automatic_batching=None,
num_workers=0, num_workers=0,
pin_memory=False, pin_memory=False,

View File

@@ -1,138 +1,138 @@
import torch # import torch
import pytest # import pytest
from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset # from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset
from pina.graph import KNNGraph # from pina.graph import KNNGraph
from torch_geometric.data import Data # from torch_geometric.data import Data
x = torch.rand((100, 20, 10)) # x = torch.rand((100, 20, 10))
pos = torch.rand((100, 20, 2)) # pos = torch.rand((100, 20, 2))
input_ = [ # input_ = [
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) # KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
for x_, pos_ in zip(x, pos) # for x_, pos_ in zip(x, pos)
] # ]
output_ = torch.rand((100, 20, 10)) # output_ = torch.rand((100, 20, 10))
x_2 = torch.rand((50, 20, 10)) # x_2 = torch.rand((50, 20, 10))
pos_2 = torch.rand((50, 20, 2)) # pos_2 = torch.rand((50, 20, 2))
input_2_ = [ # input_2_ = [
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) # KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
for x_, pos_ in zip(x_2, pos_2) # for x_, pos_ in zip(x_2, pos_2)
] # ]
output_2_ = torch.rand((50, 20, 10)) # output_2_ = torch.rand((50, 20, 10))
# Problem with a single condition # # Problem with a single condition
conditions_dict_single = { # conditions_dict_single = {
"data": { # "data": {
"input": input_, # "input": input_,
"target": output_, # "target": output_,
} # }
} # }
max_conditions_lengths_single = {"data": 100} # max_conditions_lengths_single = {"data": 100}
# Problem with multiple conditions # # Problem with multiple conditions
conditions_dict_multi = { # conditions_dict_multi = {
"data_1": { # "data_1": {
"input": input_, # "input": input_,
"target": output_, # "target": output_,
}, # },
"data_2": { # "data_2": {
"input": input_2_, # "input": input_2_,
"target": output_2_, # "target": output_2_,
}, # },
} # }
max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} # max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"conditions_dict, max_conditions_lengths", # "conditions_dict, max_conditions_lengths",
[ # [
(conditions_dict_single, max_conditions_lengths_single), # (conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_multi, max_conditions_lengths_multi), # (conditions_dict_multi, max_conditions_lengths_multi),
], # ],
) # )
def test_constructor(conditions_dict, max_conditions_lengths): # def test_constructor(conditions_dict, max_conditions_lengths):
dataset = PinaDatasetFactory( # dataset = PinaDatasetFactory(
conditions_dict, # conditions_dict,
max_conditions_lengths=max_conditions_lengths, # max_conditions_lengths=max_conditions_lengths,
automatic_batching=True, # automatic_batching=True,
) # )
assert isinstance(dataset, PinaGraphDataset) # assert isinstance(dataset, PinaGraphDataset)
assert len(dataset) == 100 # assert len(dataset) == 100
@pytest.mark.parametrize(
    "conditions_dict, max_conditions_lengths",
    [
        (conditions_dict_single, max_conditions_lengths_single),
        (conditions_dict_multi, max_conditions_lengths_multi),
    ],
)
def test_getitem(conditions_dict, max_conditions_lengths):
    dataset = PinaDatasetFactory(
        conditions_dict,
        max_conditions_lengths=max_conditions_lengths,
        automatic_batching=True,
    )
    data = dataset[50]
    assert isinstance(data, dict)
    assert all([isinstance(d["input"], Data) for d in data.values()])
    assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
    assert all(
        [d["input"].x.shape == torch.Size((20, 10)) for d in data.values()]
    )
    assert all(
        [d["target"].shape == torch.Size((20, 10)) for d in data.values()]
    )
    assert all(
        [
            d["input"].edge_index.shape == torch.Size((2, 60))
            for d in data.values()
        ]
    )
    assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()])
    data = dataset.fetch_from_idx_list([i for i in range(20)])
    assert isinstance(data, dict)
    assert all([isinstance(d["input"], Data) for d in data.values()])
    assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
    assert all(
        [d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
    )
    assert all(
        [d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
    )
    assert all(
        [
            d["input"].edge_index.shape == torch.Size((2, 1200))
            for d in data.values()
        ]
    )
    assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])
def test_input_single_condition():
    dataset = PinaDatasetFactory(
        conditions_dict_single,
        max_conditions_lengths=max_conditions_lengths_single,
        automatic_batching=True,
    )
    input_ = dataset.input
    assert isinstance(input_, dict)
    assert isinstance(input_["data"], list)
    assert all([isinstance(d, Data) for d in input_["data"]])
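
# Same check as above, but with two conditions in the dataset.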
def test_input_multi_condition():
    dataset = PinaDatasetFactory(
        conditions_dict_multi,
        max_conditions_lengths=max_conditions_lengths_multi,
        automatic_batching=True,
    )
    input_ = dataset.input
    assert isinstance(input_, dict)
    assert isinstance(input_["data_1"], list)
    assert all([isinstance(d, Data) for d in input_["data_1"]])
    assert isinstance(input_["data_2"], list)
    assert all([isinstance(d, Data) for d in input_["data_2"]])

View File

@@ -1,86 +1,86 @@
import torch
import pytest
from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset
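
# Synthetic tensors: 100 samples with 10 input features and 2 targets for the
# first condition, 50 samples for the second.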
input_tensor = torch.rand((100, 10))
output_tensor = torch.rand((100, 2))
input_tensor_2 = torch.rand((50, 10))
output_tensor_2 = torch.rand((50, 2))

conditions_dict_single = {
    "data": {
        "input": input_tensor,
        "target": output_tensor,
    }
}

conditions_dict_single_multi = {
    "data_1": {
        "input": input_tensor,
        "target": output_tensor,
    },
    "data_2": {
        "input": input_tensor_2,
        "target": output_tensor_2,
    },
}

max_conditions_lengths_single = {"data": 100}
max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
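
# Tensor-only conditions should dispatch to PinaTensorDataset.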
@pytest.mark.parametrize(
    "conditions_dict, max_conditions_lengths",
    [
        (conditions_dict_single, max_conditions_lengths_single),
        (conditions_dict_single_multi, max_conditions_lengths_multi),
    ],
)
def test_constructor_tensor(conditions_dict, max_conditions_lengths):
    dataset = PinaDatasetFactory(
        conditions_dict,
        max_conditions_lengths=max_conditions_lengths,
        automatic_batching=True,
    )
    assert isinstance(dataset, PinaTensorDataset)
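
# With automatic batching disabled, fetch_from_idx_list should slice the
# stored tensors directly and preserve their shapes.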
def test_getitem_single():
    dataset = PinaDatasetFactory(
        conditions_dict_single,
        max_conditions_lengths=max_conditions_lengths_single,
        automatic_batching=False,
    )
    tensors = dataset.fetch_from_idx_list([i for i in range(70)])
    assert isinstance(tensors, dict)
    assert list(tensors.keys()) == ["data"]
    assert sorted(list(tensors["data"].keys())) == ["input", "target"]
    assert isinstance(tensors["data"]["input"], torch.Tensor)
    assert tensors["data"]["input"].shape == torch.Size((70, 10))
    assert isinstance(tensors["data"]["target"], torch.Tensor)
    assert tensors["data"]["target"].shape == torch.Size((70, 2))
def test_getitem_multi():
    dataset = PinaDatasetFactory(
        conditions_dict_single_multi,
        max_conditions_lengths=max_conditions_lengths_multi,
        automatic_batching=False,
    )
    tensors = dataset.fetch_from_idx_list([i for i in range(70)])
    assert isinstance(tensors, dict)
    assert list(tensors.keys()) == ["data_1", "data_2"]
    assert sorted(list(tensors["data_1"].keys())) == ["input", "target"]
    assert isinstance(tensors["data_1"]["input"], torch.Tensor)
    assert tensors["data_1"]["input"].shape == torch.Size((70, 10))
    assert isinstance(tensors["data_1"]["target"], torch.Tensor)
    assert tensors["data_1"]["target"].shape == torch.Size((70, 2))
    assert sorted(list(tensors["data_2"].keys())) == ["input", "target"]
    assert isinstance(tensors["data_2"]["input"], torch.Tensor)
    assert tensors["data_2"]["input"].shape == torch.Size((50, 10))
    assert isinstance(tensors["data_2"]["target"], torch.Tensor)
    assert tensors["data_2"]["target"].shape == torch.Size((50, 2))

View File

@@ -117,6 +117,10 @@ def test_solver_train(use_lt, batch_size, compile):
        assert isinstance(solver.model, OptimizedModule)


if __name__ == "__main__":
    test_solver_train(use_lt=True, batch_size=20, compile=True)


@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_train_graph(batch_size, use_lt):