Implement PinaGraphDataset
This commit is contained in:
committed by
Nicola Demo
parent
86fe41261b
commit
4c5e1569ff
@@ -93,8 +93,7 @@ class PinaTensorDataset(PinaDataset):
|
||||
|
||||
|
||||
class PinaGraphDataset(PinaDataset):
|
||||
pass
|
||||
'''
|
||||
|
||||
def __init__(self, conditions_dict, max_conditions_lengths,
|
||||
automatic_batching):
|
||||
super().__init__(conditions_dict, max_conditions_lengths)
|
||||
@@ -113,7 +112,7 @@ class PinaGraphDataset(PinaDataset):
|
||||
to_return_dict[condition] = {k: Batch.from_data_list([v[i]
|
||||
for i in cond_idx])
|
||||
if isinstance(v, list)
|
||||
else v[cond_idx]
|
||||
else v[cond_idx].reshape(-1, *v[cond_idx].shape[2:])
|
||||
for k, v in data.items()
|
||||
}
|
||||
return to_return_dict
|
||||
@@ -132,5 +131,4 @@ class PinaGraphDataset(PinaDataset):
|
||||
return self.fetch_from_idx_list(index)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._getitem_func(idx)
|
||||
'''
|
||||
return self._getitem_func(idx)
|
||||
Reference in New Issue
Block a user