This commit is contained in:
Filippo Olivo
2024-11-28 11:06:38 +01:00
committed by Nicola Demo
parent 3c95441aac
commit f748b66194
9 changed files with 28 additions and 29 deletions

View File

@@ -51,8 +51,12 @@ class PinaDataset(Dataset):
class PinaTensorDataset(PinaDataset):
def __init__(self, conditions_dict, max_conditions_lengths,
):
automatic_batching):
super().__init__(conditions_dict, max_conditions_lengths)
if automatic_batching:
self._getitem_func = self._getitem_int
else:
self._getitem_func = self._getitem_list
def _getitem_int(self, idx):
return {
@@ -72,9 +76,7 @@ class PinaTensorDataset(PinaDataset):
return to_return_dict
def __getitem__(self, idx):
if isinstance(idx, int):
return self._getitem_int(idx)
return self._getitem_list(idx)
return self._getitem_func(idx)
class PinaGraphDataset(PinaDataset):
pass