Fix bugs (#387)
This commit is contained in:
committed by
Nicola Demo
parent
3c95441aac
commit
f748b66194
@@ -51,8 +51,12 @@ class PinaDataset(Dataset):
|
||||
|
||||
class PinaTensorDataset(PinaDataset):
|
||||
def __init__(self, conditions_dict, max_conditions_lengths,
|
||||
):
|
||||
automatic_batching):
|
||||
super().__init__(conditions_dict, max_conditions_lengths)
|
||||
if automatic_batching:
|
||||
self._getitem_func = self._getitem_int
|
||||
else:
|
||||
self._getitem_func = self._getitem_list
|
||||
|
||||
def _getitem_int(self, idx):
|
||||
return {
|
||||
@@ -72,9 +76,7 @@ class PinaTensorDataset(PinaDataset):
|
||||
return to_return_dict
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
return self._getitem_int(idx)
|
||||
return self._getitem_list(idx)
|
||||
return self._getitem_func(idx)
|
||||
|
||||
class PinaGraphDataset(PinaDataset):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user