Filippo Olivo
2024-11-28 11:06:38 +01:00
committed by Nicola Demo
parent 3c95441aac
commit f748b66194
9 changed files with 28 additions and 29 deletions

View File

@@ -95,7 +95,7 @@ class PinaDataModule(LightningDataModule):
-        logging.debug('Start initialization of Pina DataModule')
+        logging.info('Start initialization of Pina DataModule')
         super().__init__()
-        self.default_batching = automatic_batching
+        self.automatic_batching = automatic_batching
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.repeat = repeat
@@ -133,24 +133,24 @@ class PinaDataModule(LightningDataModule):
             self.train_dataset = PinaDatasetFactory(
                 self.collector_splits['train'],
                 max_conditions_lengths=self.find_max_conditions_lengths(
-                    'train'))
+                    'train'), automatic_batching=self.automatic_batching)
             if 'val' in self.collector_splits.keys():
                 self.val_dataset = PinaDatasetFactory(
                     self.collector_splits['val'],
                     max_conditions_lengths=self.find_max_conditions_lengths(
-                        'val')
+                        'val'), automatic_batching=self.automatic_batching
                 )
         elif stage == 'test':
             self.test_dataset = PinaDatasetFactory(
                 self.collector_splits['test'],
                 max_conditions_lengths=self.find_max_conditions_lengths(
-                    'test')
+                    'test'), automatic_batching=self.automatic_batching
             )
         elif stage == 'predict':
             self.predict_dataset = PinaDatasetFactory(
                 self.collector_splits['predict'],
                 max_conditions_lengths=self.find_max_conditions_lengths(
-                    'predict')
+                    'predict'), automatic_batching=self.automatic_batching
             )
         else:
             raise ValueError(
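
Note on the hunk above: `setup()` now forwards the DataModule's `automatic_batching` flag to every dataset it creates. A minimal sketch of the new factory call, with hypothetical `conditions_dict` and `max_lengths` standing in for the values that `self.collector_splits[split]` and `self.find_max_conditions_lengths(split)` provide:

```python
# Hypothetical inputs; in the module they come from collector_splits and
# find_max_conditions_lengths for the split being set up.
dataset = PinaDatasetFactory(
    conditions_dict,                      # data belonging to one split
    max_conditions_lengths=max_lengths,   # per-condition lengths of the split
    automatic_batching=True)              # the flag set on PinaDataModule
```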
@@ -237,9 +237,9 @@ class PinaDataModule(LightningDataModule):
             self.val_dataset)
         # Use default batching in torch DataLoader (good if batch size is small)
-        if self.default_batching:
+        if self.automatic_batching:
             collate = Collator(self.find_max_conditions_lengths('val'))
-            return DataLoader(self.val_dataset, self.batch_size,
+            return DataLoader(self.val_dataset, batch_size,
                               collate_fn=collate)
         collate = Collator(None)
         # Use custom batching (good if batch size is large)
@@ -252,14 +252,16 @@ class PinaDataModule(LightningDataModule):
         Create the training dataloader
         """
         # Use default batching in torch DataLoader (good if batch size is small)
-        if self.default_batching:
+        batch_size = self.batch_size if self.batch_size is not None else len(
+            self.train_dataset)
+        if self.automatic_batching:
             collate = Collator(self.find_max_conditions_lengths('train'))
-            return DataLoader(self.train_dataset, self.batch_size,
+            return DataLoader(self.train_dataset, batch_size,
                               collate_fn=collate)
         collate = Collator(None)
         # Use custom batching (good if batch size is large)
-        batch_size = self.batch_size if self.batch_size is not None else len(
-            self.train_dataset)
         sampler = PinaBatchSampler(self.train_dataset, batch_size,
                                    shuffle=False)
         return DataLoader(self.train_dataset, sampler=sampler,
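
Taken together, the two dataloader hunks move the `batch_size` fallback ahead of the branch, so a `batch_size` of `None` now falls back to the full dataset length on both paths, and the automatic-batching branch no longer passes `self.batch_size` (which could be `None`) straight to the DataLoader. A condensed sketch of the resulting logic, assuming `Collator` and `PinaBatchSampler` are in scope; the helper name is hypothetical, and the `collate_fn` of the sampler branch is inferred since the hunk is cut off:

```python
from torch.utils.data import DataLoader

def _make_dataloader(self, dataset, split):
    # Fallback computed once, before the branch, so both paths share it.
    batch_size = self.batch_size if self.batch_size is not None else len(dataset)
    if self.automatic_batching:
        # Default torch batching: one integer index per __getitem__ call,
        # items stacked by the Collator (good when the batch size is small).
        collate = Collator(self.find_max_conditions_lengths(split))
        return DataLoader(dataset, batch_size, collate_fn=collate)
    # Custom batching: PinaBatchSampler yields whole index lists, so one
    # __getitem__ call assembles an entire batch (good when it is large).
    sampler = PinaBatchSampler(dataset, batch_size, shuffle=False)
    return DataLoader(dataset, sampler=sampler, collate_fn=Collator(None))
```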

View File

@@ -51,8 +51,12 @@ class PinaDataset(Dataset):
 
 class PinaTensorDataset(PinaDataset):
     def __init__(self, conditions_dict, max_conditions_lengths,
-                 ):
+                 automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
+        if automatic_batching:
+            self._getitem_func = self._getitem_int
+        else:
+            self._getitem_func = self._getitem_list
 
     def _getitem_int(self, idx):
         return {
@@ -72,9 +76,7 @@ class PinaTensorDataset(PinaDataset):
         return to_return_dict
 
     def __getitem__(self, idx):
-        if isinstance(idx, int):
-            return self._getitem_int(idx)
-        return self._getitem_list(idx)
+        return self._getitem_func(idx)
 
 class PinaGraphDataset(PinaDataset):
     pass
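
The dataset-side change pairs with the sampler behaviour above: instead of checking `isinstance(idx, int)` on every access, `__init__` binds the right accessor once. A self-contained toy sketch of the same dispatch pattern (illustrative, not PINA code):

```python
class DispatchExample:
    """Toy version of the accessor binding added to PinaTensorDataset."""

    def __init__(self, automatic_batching):
        # Resolve the branch a single time instead of on every lookup.
        self._getitem_func = (self._getitem_int if automatic_batching
                              else self._getitem_list)

    def _getitem_int(self, idx):
        # Automatic batching: the DataLoader asks for one integer index.
        return f"item {idx}"

    def _getitem_list(self, idx):
        # Custom batching: a batch sampler hands over a list of indices.
        return [f"item {i}" for i in idx]

    def __getitem__(self, idx):
        return self._getitem_func(idx)


assert DispatchExample(True)[3] == "item 3"
assert DispatchExample(False)[[0, 1]] == ["item 0", "item 1"]
```

Binding the function in `__init__` keeps `__getitem__` branch-free on the hot path and makes the access mode explicit at construction time.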