diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 8b5f998..4eeb20e 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -93,8 +93,7 @@ class PinaTensorDataset(PinaDataset): class PinaGraphDataset(PinaDataset): - pass -''' + def __init__(self, conditions_dict, max_conditions_lengths, automatic_batching): super().__init__(conditions_dict, max_conditions_lengths) @@ -113,7 +112,7 @@ class PinaGraphDataset(PinaDataset): to_return_dict[condition] = {k: Batch.from_data_list([v[i] for i in cond_idx]) if isinstance(v, list) - else v[cond_idx] + else v[cond_idx].reshape(-1, *v[cond_idx].shape[2:]) for k, v in data.items() } return to_return_dict @@ -132,5 +131,4 @@ class PinaGraphDataset(PinaDataset): return self.fetch_from_idx_list(index) def __getitem__(self, idx): - return self._getitem_func(idx) -''' \ No newline at end of file + return self._getitem_func(idx) \ No newline at end of file