"""This module provides basic data management functionalities."""

import functools
from abc import abstractmethod

import torch
from torch.utils.data import Dataset
from torch_geometric.data import Batch, Data

from pina import LabelTensor


class PinaDatasetFactory:
    """
    Factory class for the PINA dataset.

    Depending on the type inside the conditions it creates a different
    dataset object:

    - ``PinaTensorDataset`` for ``torch.Tensor`` input points
    - ``PinaGraphDataset`` for lists of ``torch_geometric.data.Data`` objects
    """

    def __new__(cls, conditions_dict, **kwargs):
        """
        Instantiate the appropriate dataset subclass.

        :param dict conditions_dict: Mapping from condition name to a dict
            holding (at least) the key ``"input_points"``.
        :raises ValueError: If no conditions are provided, or if the input
            points are neither all tensors nor all lists.
        :return: The dataset matching the conditions' data type.
        :rtype: PinaTensorDataset | PinaGraphDataset
        """
        if len(conditions_dict) == 0:
            raise ValueError("No conditions provided")
        if all(
            isinstance(v["input_points"], torch.Tensor)
            for v in conditions_dict.values()
        ):
            return PinaTensorDataset(conditions_dict, **kwargs)
        if all(
            isinstance(v["input_points"], list)
            for v in conditions_dict.values()
        ):
            return PinaGraphDataset(conditions_dict, **kwargs)
        raise ValueError(
            "Conditions must be either torch.Tensor or list of Data "
            "objects."
        )


class PinaDataset(Dataset):
    """
    Abstract class for the PINA dataset.

    Holds the per-condition data and the shared indexing logic: the dataset
    length is driven by the largest condition, and shorter conditions are
    indexed modulo their own length so they wrap around.
    """

    def __init__(self, conditions_dict, max_conditions_lengths):
        """
        :param dict conditions_dict: Mapping from condition name to a dict
            of data entries (must contain ``"input_points"``).
        :param dict max_conditions_lengths: Per-condition cap on how many
            indices a fetched batch may use.
        """
        self.conditions_dict = conditions_dict
        self.max_conditions_lengths = max_conditions_lengths
        # Number of samples available in each condition.
        self.conditions_length = {
            k: len(v["input_points"]) for k, v in self.conditions_dict.items()
        }
        # Overall length is that of the largest condition.
        self.length = max(self.conditions_length.values())

    def _get_max_len(self):
        # Equivalent to ``self.length``; kept for backward compatibility.
        return max(
            len(condition["input_points"])
            for condition in self.conditions_dict.values()
        )

    def __len__(self):
        return self.length

    def _getitem_int(self, idx):
        # Automatic-batching path: return the idx-th sample of every
        # condition, wrapping around conditions shorter than the dataset.
        return {
            k: {
                k_data: v[k_data][idx % len(v["input_points"])]
                for k_data in v.keys()
            }
            for k, v in self.conditions_dict.items()
        }

    @staticmethod
    def _getitem_dummy(idx):
        # Manual-batching path: pass the index list through untouched;
        # it is resolved later by ``fetch_from_idx_list``.
        return idx

    def get_all_data(self):
        """Return every sample of every condition as one batch."""
        return self.fetch_from_idx_list(list(range(len(self))))

    def __getitem__(self, idx):
        # ``_getitem_func`` is bound by the subclass constructor depending
        # on whether automatic batching is enabled.
        return self._getitem_func(idx)

    @abstractmethod
    def fetch_from_idx_list(self, idx):
        """Resolve a list of indices into per-condition batched data."""


class PinaTensorDataset(PinaDataset):
    """Dataset for conditions whose data entries are ``torch.Tensor``s."""

    def __init__(
        self, conditions_dict, max_conditions_lengths, automatic_batching
    ):
        """
        :param dict conditions_dict: See :class:`PinaDataset`.
        :param dict max_conditions_lengths: See :class:`PinaDataset`.
        :param bool automatic_batching: If ``True`` the DataLoader requests
            single integer indices; otherwise it passes index lists through.
        """
        super().__init__(conditions_dict, max_conditions_lengths)
        self._getitem_func = (
            self._getitem_int if automatic_batching else self._getitem_dummy
        )

    def fetch_from_idx_list(self, idx):
        """
        Gather, for each condition, the tensors at the given indices.

        :param list[int] idx: Indices selected by the sampler.
        :return: Mapping condition name -> {entry name -> indexed tensor}.
        :rtype: dict
        """
        to_return_dict = {}
        for condition, data in self.conditions_dict.items():
            cond_idx = idx[: self.max_conditions_lengths[condition]]
            condition_len = self.conditions_length[condition]
            if self.length > condition_len:
                # Wrap indices for conditions shorter than the dataset.
                cond_idx = [i % condition_len for i in cond_idx]
            to_return_dict[condition] = {
                k: v[cond_idx] for k, v in data.items()
            }
        return to_return_dict

    @property
    def input_points(self):
        """
        Method to return input points for training.
        """
        return {k: v["input_points"] for k, v in self.conditions_dict.items()}


class PinaBatch(Batch):
    """
    Add extract function to torch_geometric Batch object
    """

    def __init__(self):
        # NOTE: the original passed ``self`` to the parent constructor,
        # which would be bound to the first data argument — fixed.
        super().__init__()

    def extract(self, labels):
        """
        Perform extraction of labels on node features (x)

        :param labels: Labels to extract
        :type labels: list[str] | tuple[str] | str
        :return: Batch object with extraction performed on x
        :rtype: PinaBatch
        """
        self.x = self.x.extract(labels)
        return self


class PinaGraphDataset(PinaDataset):
    """Dataset for conditions whose input points are lists of graphs."""

    def __init__(
        self, conditions_dict, max_conditions_lengths, automatic_batching
    ):
        """
        :param dict conditions_dict: See :class:`PinaDataset`; here
            ``"input_points"`` is a list of graph objects and
            ``"output_points"`` holds the targets.
        :param dict max_conditions_lengths: See :class:`PinaDataset`.
        :param bool automatic_batching: See :class:`PinaTensorDataset`.
        """
        super().__init__(conditions_dict, max_conditions_lengths)
        self.in_labels = {}
        self.out_labels = None
        self._getitem_func = (
            self._getitem_int if automatic_batching else self._getitem_dummy
        )
        # Inspect the first condition to record which attributes carry
        # LabelTensor labels, so they can be re-attached after collation
        # (batching loses the labels).
        first_condition = next(iter(conditions_dict.values()))
        example_input = first_condition["input_points"][0]
        for name, attr in example_input.items():
            if isinstance(attr, LabelTensor):
                self.in_labels[name] = attr.stored_labels
        example_output = first_condition["output_points"][0]
        if isinstance(example_output, LabelTensor):
            self.out_labels = example_output.labels
        # Wrap the batch builders with label re-attachment only if needed.
        self._create_graph_batch_from_list = (
            self._labelise_batch(self._base_create_graph_batch_from_list)
            if self.in_labels
            else self._base_create_graph_batch_from_list
        )
        self._create_output_batch = (
            self._labelise_tensor(self._base_create_output_batch)
            if self.out_labels is not None
            else self._base_create_output_batch
        )

    def fetch_from_idx_list(self, idx):
        """
        Gather, for each condition, the graphs/outputs at the given indices.

        :param list[int] idx: Indices selected by the sampler.
        :return: Mapping condition name -> {entry name -> batched data}.
        :rtype: dict
        """
        to_return_dict = {}
        for condition, data in self.conditions_dict.items():
            cond_idx = idx[: self.max_conditions_lengths[condition]]
            condition_len = self.conditions_length[condition]
            if self.length > condition_len:
                # Wrap indices for conditions shorter than the dataset.
                cond_idx = [i % condition_len for i in cond_idx]
            # BUGFIX: index with the capped/wrapped ``cond_idx`` (the
            # original used the raw ``idx``, discarding the work above).
            to_return_dict[condition] = {
                k: (
                    self._create_graph_batch_from_list(
                        [v[i] for i in cond_idx]
                    )
                    if isinstance(v, list)
                    else self._create_output_batch(v[cond_idx])
                )
                for k, v in data.items()
            }
        return to_return_dict

    def _base_create_graph_batch_from_list(self, data):
        # Collate a list of Data objects into a single PinaBatch.
        return PinaBatch.from_data_list(data)

    def _base_create_output_batch(self, data):
        # Flatten the leading (batch, per-graph) dimensions into one.
        return data.reshape(-1, *data.shape[2:])

    def _labelise_batch(self, func):
        """Decorate ``func`` so stored labels are re-attached to batch attrs."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            batch = func(*args, **kwargs)
            for key, labels in self.in_labels.items():
                tmp = batch[key]
                tmp.labels = labels
                batch[key] = tmp
            return batch

        return wrapper

    def _labelise_tensor(self, func):
        """Decorate ``func`` so LabelTensor outputs get their labels back."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            out = func(*args, **kwargs)
            if isinstance(out, LabelTensor):
                out.labels = self.out_labels
            return out

        return wrapper

    def create_graph_batch(self, data):
        """
        Collate ``data`` into a batch.

        :param data: Either a list of Data objects (collated into a
            PinaBatch) or a tensor of stacked outputs (flattened).
        :return: The collated batch.
        """
        if isinstance(data[0], Data):
            return self._create_graph_batch_from_list(data)
        return self._create_output_batch(data)