Implement PinaGraphDataset

This commit is contained in:
FilippoOlivo
2025-02-04 19:47:10 +01:00
committed by Nicola Demo
parent 86fe41261b
commit 4c5e1569ff

View File

@@ -93,8 +93,7 @@ class PinaTensorDataset(PinaDataset):
class PinaGraphDataset(PinaDataset):
pass
'''
def __init__(self, conditions_dict, max_conditions_lengths,
automatic_batching):
super().__init__(conditions_dict, max_conditions_lengths)
@@ -113,7 +112,7 @@ class PinaGraphDataset(PinaDataset):
to_return_dict[condition] = {k: Batch.from_data_list([v[i]
for i in cond_idx])
if isinstance(v, list)
else v[cond_idx]
else v[cond_idx].reshape(-1, *v[cond_idx].shape[2:])
for k, v in data.items()
}
return to_return_dict
@@ -132,5 +131,4 @@ class PinaGraphDataset(PinaDataset):
return self.fetch_from_idx_list(index)
def __getitem__(self, idx):
return self._getitem_func(idx)
'''
return self._getitem_func(idx)