Add functionalities to DataModule and data loaders + tests for datasets and DataModule (#453)

* Add num_workers and pin_memory arguments to DataLoader and DataModule; add tests
Filippo Olivo
2025-02-18 09:10:23 +01:00
committed by Nicola Demo
parent 9cae9a438f
commit 571ef7f9e2
5 changed files with 455 additions and 29 deletions

View File

@@ -1,4 +1,5 @@
 import logging
+import warnings
 from lightning.pytorch import LightningDataModule
 import torch
 from ..label_tensor import LabelTensor
@@ -8,6 +9,7 @@ from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
 from ..collector import Collector

 class DummyDataloader:
     """
     Dummy dataloader used when batch size is None. It collects all the data
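A toy illustration of the single-batch pattern behind DummyDataloader (illustrative names, not the PINA source): when batch_size is None, the whole dataset is collected once and served as a single batch.

class OneBatchLoader:
    """Toy single-batch loader: the entire dataset is one batch."""

    def __init__(self, dataset):
        # Gather everything up front instead of batching per iteration.
        self.batch = [dataset[i] for i in range(len(dataset))]

    def __len__(self):
        return 1  # one batch per epoch

    def __next__(self):
        return self.batch  # every call serves the same full batch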
@@ -57,7 +59,7 @@ class Collator:
         self.max_conditions_lengths = max_conditions_lengths
         self.callable_function = self._collate_custom_dataloader if \
             max_conditions_lengths is None else (
                 self._collate_standard_dataloader)
         self.dataset = dataset

     def _collate_custom_dataloader(self, batch):
@@ -95,7 +97,7 @@ class Collator:
 class PinaSampler:

-    def __new__(self, dataset, batch_size, shuffle, automatic_batching):
+    def __new__(cls, dataset, shuffle):
         if (torch.distributed.is_available() and
                 torch.distributed.is_initialized()):
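The new signature turns __new__ into a pure factory: it returns one of the stock torch samplers instead of carrying batching arguments. A minimal sketch of the dispatch pattern, assuming nothing beyond standard PyTorch (illustrative, not the library source):

import torch
from torch.utils.data import RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

class SamplerFactory:
    def __new__(cls, dataset, shuffle):
        # In a distributed run, defer to DistributedSampler.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return DistributedSampler(dataset, shuffle=shuffle)
        # Otherwise fall back to the standard single-process samplers.
        return RandomSampler(dataset) if shuffle else SequentialSampler(dataset)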
@@ -123,15 +125,35 @@ class PinaDataModule(LightningDataModule):
             batch_size=None,
             shuffle=True,
             repeat=False,
-            automatic_batching=False
+            automatic_batching=False,
+            num_workers=0,
+            pin_memory=False,
     ):
         """
-        Initialize the object, creating dataset based on input problem
-        :param problem: Problem where data are defined
-        :param train_size: number/percentage of elements in train split
-        :param test_size: number/percentage of elements in test split
-        :param val_size: number/percentage of elements in evaluation split
-        :param batch_size: batch size used for training
+        Initialize the object, creating datasets based on the input problem.
+
+        :param problem: The problem defining the dataset.
+        :type problem: AbstractProblem
+        :param train_size: Fraction or number of elements in the training split.
+        :type train_size: float
+        :param test_size: Fraction or number of elements in the test split.
+        :type test_size: float
+        :param val_size: Fraction or number of elements in the validation split.
+        :type val_size: float
+        :param predict_size: Fraction or number of elements in the prediction split.
+        :type predict_size: float
+        :param batch_size: Batch size used for training. If None, the entire
+            dataset is used per batch.
+        :type batch_size: int or None
+        :param shuffle: Whether to shuffle the dataset before splitting.
+        :type shuffle: bool
+        :param repeat: Whether to repeat the dataset indefinitely.
+        :type repeat: bool
+        :param automatic_batching: Whether to enable automatic batching.
+        :type automatic_batching: bool
+        :param num_workers: Number of worker processes for data loading.
+            Default 0 (serial loading).
+        :type num_workers: int
+        :param pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU. Default False.
+        :type pin_memory: bool
         """
         logging.debug('Start initialization of Pina DataModule')
         logging.info('Start initialization of Pina DataModule')
@@ -170,6 +192,15 @@ class PinaDataModule(LightningDataModule):
         collector = Collector(problem)
         collector.store_fixed_data()
         collector.store_sample_domains()
+        if batch_size is None and num_workers != 0:
+            warnings.warn(
+                "Setting num_workers when batch_size is None has no effect "
+                "on the data loading process.")
+        if batch_size is None and pin_memory:
+            warnings.warn("Setting pin_memory to True has no effect when "
+                          "batch_size is None.")
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
         self.collector_splits = self._create_splits(collector, splits_dict)
         self.transfer_batch_to_device = self._transfer_batch_to_device
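A usage sketch of the new guards (reusing the SupervisedProblem setup from the tests below; treat it as illustrative): combining batch_size=None with num_workers or pin_memory now emits a warning instead of being silently ignored.

import warnings
import torch
from pina.data import PinaDataModule
from pina.problem.zoo import SupervisedProblem

problem = SupervisedProblem(input_=torch.rand((100, 10)),
                            output_=torch.rand((100, 2)))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    PinaDataModule(problem, batch_size=None, num_workers=4, pin_memory=True)
assert len(caught) >= 2  # at least the two warnings added above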
@@ -271,20 +302,27 @@ class PinaDataModule(LightningDataModule):
                 dataset_dict[key].update({condition_name: data})
         return dataset_dict

     def _create_dataloader(self, split, dataset):
         shuffle = self.shuffle if split == 'train' else False
+        # Suppress the warning about num_workers. In many cases, especially
+        # for PINNs, serial data loading can outperform parallel data loading.
+        warnings.filterwarnings(
+            "ignore",
+            message=(
+                r"The '(train|val|test)_dataloader' does not have many "
+                r"workers which may be a bottleneck."),
+            module="lightning.pytorch.trainer.connectors.data_connector"
+        )
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            sampler = PinaSampler(dataset, self.batch_size,
-                                  shuffle, self.automatic_batching)
+            sampler = PinaSampler(dataset, shuffle)
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths(split))
             else:
                 collate = Collator(None, dataset)
-            return DataLoader(dataset, self.batch_size,
-                              collate_fn=collate, sampler=sampler)
+            return DataLoader(dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler,
+                              num_workers=self.num_workers)
         dataloader = DummyDataloader(dataset)
         dataloader.dataset = self._transfer_batch_to_device(
             dataloader.dataset, self.trainer.strategy.root_device, 0)
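For background on the two forwarded options, a plain-PyTorch sketch (independent of PINA): num_workers moves sample loading into worker subprocesses, while pin_memory allocates each batch in page-locked host memory so host-to-GPU copies can run faster and asynchronously.

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.rand(100, 10), torch.rand(100, 2))
loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=True)
for xb, yb in loader:
    # With pin_memory=True and CUDA available, xb.is_pinned() is True and
    # xb.to('cuda', non_blocking=True) can overlap the copy with compute.
    pass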

View File

@@ -18,6 +18,8 @@ class Trainer(lightning.pytorch.Trainer):
                 predict_size=0.,
                 compile=None,
                 automatic_batching=None,
+                num_workers=None,
+                pin_memory=None,
                 **kwargs):
         """
         PINA Trainer class for customizing every aspect of training via flags.
@@ -44,6 +46,10 @@ class Trainer(lightning.pytorch.Trainer):
             performed. Please avoid using automatic batching when batch_size is
             large, default False.
         :type automatic_batching: bool
+        :param num_workers: Number of worker processes for data loading.
+            Default 0 (serial loading).
+        :type num_workers: int
+        :param pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU. Default False.
+        :type pin_memory: bool

         :Keyword Arguments:
             The additional keyword arguments specify the training setup
@@ -60,6 +66,14 @@ class Trainer(lightning.pytorch.Trainer):
             check_consistency(automatic_batching, bool)
         if compile is not None:
             check_consistency(compile, bool)
+        if pin_memory is not None:
+            check_consistency(pin_memory, bool)
+        else:
+            pin_memory = False
+        if num_workers is not None:
+            check_consistency(num_workers, int)
+        else:
+            num_workers = 0
         if train_size + test_size + val_size + predict_size > 1:
             raise ValueError('train_size, test_size, val_size and predict_size '
                              'must sum up to 1.')
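A usage sketch of the forwarded arguments, mirroring the solver and problem setup from the new tests (illustrative, not part of the commit):

import torch
from pina import Trainer
from pina.solvers import SupervisedSolver
from pina.problem.zoo import SupervisedProblem

problem = SupervisedProblem(input_=torch.rand((100, 10)),
                            output_=torch.rand((100, 10)))
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
                  test_size=0., num_workers=2, pin_memory=True)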
@@ -96,16 +110,13 @@ class Trainer(lightning.pytorch.Trainer):
         # set attributes
         self.compile = compile
-        self.automatic_batching = automatic_batching
-        self.train_size = train_size
-        self.test_size = test_size
-        self.val_size = val_size
-        self.predict_size = predict_size
         self.solver = solver
         self.batch_size = batch_size
         self._move_to_device()
         self.data_module = None
-        self._create_loader()
+        self._create_datamodule(train_size, test_size, val_size, predict_size,
+                                batch_size, automatic_batching, pin_memory,
+                                num_workers)

         # logging
         self.logging_kwargs = {
@@ -127,7 +138,15 @@ class Trainer(lightning.pytorch.Trainer):
             pb.unknown_parameters[key] = torch.nn.Parameter(
                 pb.unknown_parameters[key].data.to(device))

-    def _create_loader(self):
+    def _create_datamodule(self,
+                           train_size,
+                           test_size,
+                           val_size,
+                           predict_size,
+                           batch_size,
+                           automatic_batching,
+                           pin_memory,
+                           num_workers):
         """
         This method is used here because, if resampling is needed during
         training, there is no need to touch the
@@ -136,8 +155,8 @@ class Trainer(lightning.pytorch.Trainer):
         if not self.solver.problem.are_all_domains_discretised:
             error_message = '\n'.join([
                 f"""{" " * 13} ---> Domain {key} {
                     "sampled" if key in self.solver.problem.discretised_domains else
                     "not sampled"}""" for key in
                 self.solver.problem.domains.keys()
             ])
             raise RuntimeError('Cannot create Trainer if not all conditions '
@@ -145,12 +164,14 @@ class Trainer(lightning.pytorch.Trainer):
                                f'{error_message}')
         self.data_module = PinaDataModule(
             self.solver.problem,
-            train_size=self.train_size,
-            test_size=self.test_size,
-            val_size=self.val_size,
-            predict_size=self.predict_size,
-            batch_size=self.batch_size,
-            automatic_batching=self.automatic_batching)
+            train_size=train_size,
+            test_size=test_size,
+            val_size=val_size,
+            predict_size=predict_size,
+            batch_size=batch_size,
+            automatic_batching=automatic_batching,
+            num_workers=num_workers,
+            pin_memory=pin_memory)

     def train(self, **kwargs):
         """

View File

@@ -0,0 +1,178 @@
import torch
import pytest
from pina.data import PinaDataModule
from pina.data.dataset import PinaTensorDataset, PinaGraphDataset
from pina.problem.zoo import SupervisedProblem
from pina.graph import RadiusGraph
from pina.data.data_module import DummyDataloader
from pina import Trainer
from pina.solvers import SupervisedSolver
from torch_geometric.data import Batch
from torch.utils.data import DataLoader
input_tensor = torch.rand((100, 10))
output_tensor = torch.rand((100, 2))
x = torch.rand((100, 50, 10))
pos = torch.rand((100, 50, 2))
input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True)
output_graph = torch.rand((100, 50, 10))
@pytest.mark.parametrize(
"input_, output_",
[
(input_tensor, output_tensor),
(input_graph, output_graph)
]
)
def test_constructor(input_, output_):
problem = SupervisedProblem(input_=input_, output_=output_)
PinaDataModule(problem)
@pytest.mark.parametrize(
"input_, output_",
[
(input_tensor, output_tensor),
(input_graph, output_graph)
]
)
@pytest.mark.parametrize(
"train_size, val_size, test_size",
[
(.7, .2, .1),
(.7, .3, 0)
]
)
def test_setup_train(input_, output_, train_size, val_size, test_size):
problem = SupervisedProblem(input_=input_, output_=output_)
dm = PinaDataModule(problem, train_size=train_size, val_size=val_size, test_size=test_size)
dm.setup()
assert hasattr(dm, "train_dataset")
if isinstance(input_, torch.Tensor):
assert isinstance(dm.train_dataset, PinaTensorDataset)
else:
assert isinstance(dm.train_dataset, PinaGraphDataset)
#assert len(dm.train_dataset) == int(len(input_) * train_size)
if test_size > 0:
assert hasattr(dm, "test_dataset")
assert dm.test_dataset is None
else:
assert not hasattr(dm, "test_dataset")
assert hasattr(dm, "val_dataset")
if isinstance(input_, torch.Tensor):
assert isinstance(dm.val_dataset, PinaTensorDataset)
else:
assert isinstance(dm.val_dataset, PinaGraphDataset)
#assert len(dm.val_dataset) == int(len(input_) * val_size)
@pytest.mark.parametrize(
"input_, output_",
[
(input_tensor, output_tensor),
(input_graph, output_graph)
]
)
@pytest.mark.parametrize(
"train_size, val_size, test_size",
[
(.7, .2, .1),
(0., 0., 1.)
]
)
def test_setup_test(input_, output_, train_size, val_size, test_size):
problem = SupervisedProblem(input_=input_, output_=output_)
dm = PinaDataModule(problem, train_size=train_size, val_size=val_size, test_size=test_size)
dm.setup(stage='test')
if train_size > 0:
assert hasattr(dm, "train_dataset")
assert dm.train_dataset is None
else:
assert not hasattr(dm, "train_dataset")
if val_size > 0:
assert hasattr(dm, "val_dataset")
assert dm.val_dataset is None
else:
assert not hasattr(dm, "val_dataset")
assert hasattr(dm, "test_dataset")
if isinstance(input_, torch.Tensor):
assert isinstance(dm.test_dataset, PinaTensorDataset)
else:
assert isinstance(dm.test_dataset, PinaGraphDataset)
#assert len(dm.test_dataset) == int(len(input_) * test_size)
@pytest.mark.parametrize(
"input_, output_",
[
(input_tensor, output_tensor),
(input_graph, output_graph)
]
)
def test_dummy_dataloader(input_, output_):
problem = SupervisedProblem(input_=input_, output_=output_)
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
trainer = Trainer(solver, batch_size=None, train_size=.7, val_size=.3, test_size=0.)
dm = trainer.data_module
dm.setup()
dm.trainer = trainer
dataloader = dm.train_dataloader()
assert isinstance(dataloader, DummyDataloader)
assert len(dataloader) == 1
data = next(dataloader)
assert isinstance(data, list)
assert isinstance(data[0], tuple)
if isinstance(input_, RadiusGraph):
assert isinstance(data[0][1]['input_points'], Batch)
else:
assert isinstance(data[0][1]['input_points'], torch.Tensor)
assert isinstance(data[0][1]['output_points'], torch.Tensor)
dataloader = dm.val_dataloader()
assert isinstance(dataloader, DummyDataloader)
assert len(dataloader) == 1
data = next(dataloader)
assert isinstance(data, list)
assert isinstance(data[0], tuple)
if isinstance(input_, RadiusGraph):
assert isinstance(data[0][1]['input_points'], Batch)
else:
assert isinstance(data[0][1]['input_points'], torch.Tensor)
assert isinstance(data[0][1]['output_points'], torch.Tensor)
@pytest.mark.parametrize(
"input_, output_",
[
(input_tensor, output_tensor),
(input_graph, output_graph)
]
)
def test_dataloader(input_, output_):
problem = SupervisedProblem(input_=input_, output_=output_)
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3, test_size=0.)
dm = trainer.data_module
dm.setup()
dm.trainer = trainer
dataloader = dm.train_dataloader()
assert isinstance(dataloader, DataLoader)
assert len(dataloader) == 7
data = next(iter(dataloader))
assert isinstance(data, dict)
if isinstance(input_, RadiusGraph):
assert isinstance(data['data']['input_points'], Batch)
else:
assert isinstance(data['data']['input_points'], torch.Tensor)
assert isinstance(data['data']['output_points'], torch.Tensor)
dataloader = dm.val_dataloader()
assert isinstance(dataloader, DataLoader)
assert len(dataloader) == 3
data = next(iter(dataloader))
assert isinstance(data, dict)
if isinstance(input_, RadiusGraph):
assert isinstance(data['data']['input_points'], Batch)
else:
assert isinstance(data['data']['input_points'], torch.Tensor)
assert isinstance(data['data']['output_points'], torch.Tensor)
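# Note on the two batch formats exercised above: with batch_size=None the
# DummyDataloader yields a list of tuples whose second element is the
# condition data dict, while with a batch size set the standard DataLoader
# yields a dict keyed by condition name (here 'data').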

View File

@@ -0,0 +1,101 @@
import torch
import pytest
from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset
from pina.graph import KNNGraph
from torch_geometric.data import Data
x = torch.rand((100, 20, 10))
pos = torch.rand((100, 20, 2))
input_ = KNNGraph(x=x, pos=pos, k=3, build_edge_attr=True)
output_ = torch.rand((100, 20, 10))
x_2 = torch.rand((50, 20, 10))
pos_2 = torch.rand((50, 20, 2))
input_2_ = KNNGraph(x=x_2, pos=pos_2, k=3, build_edge_attr=True)
output_2_ = torch.rand((50, 20, 10))
# Problem with a single condition
conditions_dict_single = {
'data': {
'input_points': input_.data,
'output_points': output_,
}
}
max_conditions_lengths_single = {
'data': 100
}
# Problem with multiple conditions
conditions_dict_single_multi = {
'data_1': {
'input_points': input_.data,
'output_points': output_,
},
'data_2': {
'input_points': input_2_.data,
'output_points': output_2_,
}
}
max_conditions_lengths_multi = {
'data_1': 100,
'data_2': 50
}
@pytest.mark.parametrize(
"conditions_dict, max_conditions_lengths",
[
(conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_single_multi, max_conditions_lengths_multi)
]
)
def test_constructor(conditions_dict, max_conditions_lengths):
dataset = PinaDatasetFactory(conditions_dict,
max_conditions_lengths=max_conditions_lengths,
automatic_batching=True)
assert isinstance(dataset, PinaGraphDataset)
assert len(dataset) == 100
@pytest.mark.parametrize(
"conditions_dict, max_conditions_lengths",
[
(conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_single_multi, max_conditions_lengths_multi)
]
)
def test_getitem(conditions_dict, max_conditions_lengths):
dataset = PinaDatasetFactory(conditions_dict,
max_conditions_lengths=max_conditions_lengths,
automatic_batching=True)
data = dataset[50]
assert isinstance(data, dict)
assert all([isinstance(d['input_points'], Data)
for d in data.values()])
assert all([isinstance(d['output_points'], torch.Tensor)
for d in data.values()])
assert all([d['input_points'].x.shape == torch.Size((20, 10))
for d in data.values()])
assert all([d['output_points'].shape == torch.Size((20, 10))
for d in data.values()])
assert all([d['input_points'].edge_index.shape ==
torch.Size((2, 60)) for d in data.values()])
assert all([d['input_points'].edge_attr.shape[0]
== 60 for d in data.values()])
data = dataset.fetch_from_idx_list([i for i in range(20)])
assert isinstance(data, dict)
assert all([isinstance(d['input_points'], Data)
for d in data.values()])
assert all([isinstance(d['output_points'], torch.Tensor)
for d in data.values()])
assert all([d['input_points'].x.shape == torch.Size((400, 10))
for d in data.values()])
assert all([d['output_points'].shape == torch.Size((400, 10))
for d in data.values()])
assert all([d['input_points'].edge_index.shape ==
torch.Size((2, 1200)) for d in data.values()])
assert all([d['input_points'].edge_attr.shape[0]
== 1200 for d in data.values()])
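# Shape bookkeeping for the asserts above: k=3 neighbours per node over 20
# nodes gives 20 * 3 = 60 directed edges per graph; fetching 20 graphs
# concatenates them into one batched graph with 20 * 20 = 400 nodes and
# 20 * 60 = 1200 edges.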

View File

@@ -0,0 +1,88 @@
import torch
import pytest
from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset
input_tensor = torch.rand((100, 10))
output_tensor = torch.rand((100, 2))
input_tensor_2 = torch.rand((50, 10))
output_tensor_2 = torch.rand((50, 2))
conditions_dict_single = {
'data': {
'input_points': input_tensor,
'output_points': output_tensor,
}
}
conditions_dict_single_multi = {
'data_1': {
'input_points': input_tensor,
'output_points': output_tensor,
},
'data_2': {
'input_points': input_tensor_2,
'output_points': output_tensor_2,
}
}
max_conditions_lengths_single = {
'data': 100
}
max_conditions_lengths_multi = {
'data_1': 100,
'data_2': 50
}
@pytest.mark.parametrize(
"conditions_dict, max_conditions_lengths",
[
(conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_single_multi, max_conditions_lengths_multi)
]
)
def test_constructor_tensor(conditions_dict, max_conditions_lengths):
dataset = PinaDatasetFactory(conditions_dict,
max_conditions_lengths=max_conditions_lengths,
automatic_batching=True)
assert isinstance(dataset, PinaTensorDataset)
def test_getitem_single():
dataset = PinaDatasetFactory(conditions_dict_single,
max_conditions_lengths=max_conditions_lengths_single,
automatic_batching=False)
tensors = dataset.fetch_from_idx_list([i for i in range(70)])
assert isinstance(tensors, dict)
assert list(tensors.keys()) == ['data']
assert sorted(list(tensors['data'].keys())) == [
'input_points', 'output_points']
assert isinstance(tensors['data']['input_points'], torch.Tensor)
assert tensors['data']['input_points'].shape == torch.Size((70, 10))
assert isinstance(tensors['data']['output_points'], torch.Tensor)
assert tensors['data']['output_points'].shape == torch.Size((70, 2))
def test_getitem_multi():
dataset = PinaDatasetFactory(conditions_dict_single_multi,
max_conditions_lengths=max_conditions_lengths_multi,
automatic_batching=False)
tensors = dataset.fetch_from_idx_list([i for i in range(70)])
assert isinstance(tensors, dict)
assert list(tensors.keys()) == ['data_1', 'data_2']
assert sorted(list(tensors['data_1'].keys())) == [
'input_points', 'output_points']
assert isinstance(tensors['data_1']['input_points'], torch.Tensor)
assert tensors['data_1']['input_points'].shape == torch.Size((70, 10))
assert isinstance(tensors['data_1']['output_points'], torch.Tensor)
assert tensors['data_1']['output_points'].shape == torch.Size((70, 2))
assert sorted(list(tensors['data_2'].keys())) == [
'input_points', 'output_points']
assert isinstance(tensors['data_2']['input_points'], torch.Tensor)
assert tensors['data_2']['input_points'].shape == torch.Size((50, 10))
assert isinstance(tensors['data_2']['output_points'], torch.Tensor)
assert tensors['data_2']['output_points'].shape == torch.Size((50, 2))
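# Note: requesting 70 indices returns all 70 samples for 'data_1' (length
# 100) but only the 50 available for 'data_2', as the shape asserts above
# check.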