From 6bb44052b07ea2800697891f64499d196355010c Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 13 Nov 2025 10:48:47 +0100 Subject: [PATCH] add update_data and input functions --- pina/data/dataset.py | 67 ++++++++++++++++---------------------------- 1 file changed, 24 insertions(+), 43 deletions(-) diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 88e86fe..674b512 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -54,6 +54,7 @@ class PinaDataset(Dataset): automatic_batching if automatic_batching is not None else True ) self.stack_fn = {} + self.is_graph_dataset = False # Determine stacking functions for each data type (used in collate_fn) for k, v in data_dict.items(): if isinstance(v, LabelTensor): @@ -64,6 +65,7 @@ class PinaDataset(Dataset): isinstance(item, (Data, Graph)) for item in v ): self.stack_fn[k] = LabelBatch.from_data_list + self.is_graph_dataset = True else: raise ValueError( f"Unsupported data type for stacking: {type(v)}" @@ -104,55 +106,34 @@ class PinaDataset(Dataset): [data[i] for i in idx_list] ) else: + print(data) to_return[field_name] = data[idx_list] return to_return - -class PinaGraphDataset(Dataset): - def __init__(self, data_dict, automatic_batching=None): + def update_data(self, update_dict): """ - Initialize the instance by storing the conditions dictionary. - - :param dict conditions_dict: A dictionary mapping condition names to - their respective data. Each key represents a condition name, and the - corresponding value is a dictionary containing the associated data. + Update the dataset's data in-place. + :param dict update_dict: A dictionary where keys are condition names + and values are dictionaries with updated data for those conditions. """ + for field_name, updates in update_dict.items(): + if field_name not in self.data: + raise KeyError( + f"Condition '{field_name}' not found in dataset." + ) + if not isinstance(updates, (LabelTensor, torch.Tensor)): + raise ValueError( + f"Updates for condition '{field_name}' must be of type " + f"LabelTensor or torch.Tensor." + ) + self.data[field_name] = updates - # Store the conditions dictionary - self.data = data_dict - self.automatic_batching = ( - automatic_batching if automatic_batching is not None else True - ) - - def __len__(self): - return len(next(iter(self.data.values()))) - - def __getitem__(self, idx): + @property + def input(self): """ - Return the data at the given index in the dataset. + Get the input data from the dataset. - :param int idx: Index. - :return: A dictionary containing the data at the given index. - :rtype: dict + :return: The input data. + :rtype: torch.Tensor | LabelTensor | Data | Graph """ - - if self.automatic_batching: - # Return the data at the given index - return { - field_name: data[idx] for field_name, data in self.data.items() - } - return idx - - def _getitem_from_list(self, idx_list): - """ - Return data from the dataset given a list of indices. - - :param list[int] idx_list: List of indices. - :return: A dictionary containing the data at the given indices. - :rtype: dict - """ - - return { - field_name: [data[i] for i in idx_list] - for field_name, data in self.data.items() - } + return self.data["input"]