From 4c5e1569ffbd39aca85fa7cc577e065e20679c35 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 4 Feb 2025 19:47:10 +0100 Subject: [PATCH] Implement PinaGraphDataset --- pina/data/dataset.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 8b5f998..4eeb20e 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -93,8 +93,7 @@ class PinaTensorDataset(PinaDataset): class PinaGraphDataset(PinaDataset): - pass -''' + def __init__(self, conditions_dict, max_conditions_lengths, automatic_batching): super().__init__(conditions_dict, max_conditions_lengths) @@ -113,7 +112,7 @@ class PinaGraphDataset(PinaDataset): to_return_dict[condition] = {k: Batch.from_data_list([v[i] for i in cond_idx]) if isinstance(v, list) - else v[cond_idx] + else v[cond_idx].reshape(-1, *v[cond_idx].shape[2:]) for k, v in data.items() } return to_return_dict @@ -132,5 +131,4 @@ class PinaGraphDataset(PinaDataset): return self.fetch_from_idx_list(index) def __getitem__(self, idx): - return self._getitem_func(idx) -''' \ No newline at end of file + return self._getitem_func(idx) \ No newline at end of file