diff --git a/pina/__init__.py b/pina/__init__.py
index 30f35a6..c02e6de 100644
--- a/pina/__init__.py
+++ b/pina/__init__.py
@@ -1,6 +1,7 @@
 __all__ = [
-    "PINN", "Trainer", "LabelTensor", "Plotter", "Condition",
-    "SamplePointDataset", "PinaDataModule", "PinaDataLoader"
+    "Trainer", "LabelTensor", "Plotter", "Condition",
+    "SamplePointDataset", "PinaDataModule", "PinaDataLoader",
+    "TorchOptimizer", "TorchScheduler", "Graph"
 ]
 
 from .meta import *
@@ -12,3 +13,6 @@ from .condition.condition import Condition
 from .data import SamplePointDataset
 from .data import PinaDataModule
 from .data import PinaDataLoader
+from .optim import TorchOptimizer
+from .optim import TorchScheduler
+from .graph import Graph
\ No newline at end of file
diff --git a/pina/collector.py b/pina/collector.py
index c48c674..e75b49c 100644
--- a/pina/collector.py
+++ b/pina/collector.py
@@ -1,6 +1,3 @@
-from sympy.strategies.branch import condition
-
-from . import LabelTensor
 from .utils import check_consistency, merge_tensors
@@ -16,6 +13,8 @@ class Collector:
         #     }
         # those variables are used for the dataloading
         self._data_collections = {name: {} for name in self.problem.conditions}
+        self.conditions_name = {i: name for i, name in
+                                enumerate(self.problem.conditions)}
 
         # variables used to check that all conditions are sampled
         self._is_conditions_ready = {
@@ -101,7 +100,8 @@ class Collector:
         """
         Add input points to a sampled condition
 
-        :param new_points_dict: Dictonary of input points (condition_name: LabelTensor)
+        :param new_points_dict: Dictionary of input points (condition_name:
+            LabelTensor)
         :raises RuntimeError: if at least one condition is not already sampled
         """
         for k, v in new_points_dict.items():
diff --git a/pina/data/base_dataset.py b/pina/data/base_dataset.py
index d859aac..5e27fc9 100644
--- a/pina/data/base_dataset.py
+++ b/pina/data/base_dataset.py
@@ -1,10 +1,12 @@
 """
 Basic data module implementation
 """
-from torch.utils.data import Dataset
 import torch
+import logging
+
+from torch.utils.data import Dataset
+
 from ..label_tensor import LabelTensor
-from ..graph import Graph
 
 
 class BaseDataset(Dataset):
@@ -12,10 +14,9 @@ class BaseDataset(Dataset):
     BaseDataset class, which handle initialization and data retrieval
     :var condition_indices: List of indices
     :var device: torch.device
-    :var condition_names: dict of condition index and corresponding name
     """
 
-    def __new__(cls, problem, device):
+    def __new__(cls, problem=None, device=torch.device('cpu')):
         """
         Ensure correct definition of __slots__ before initialization
         :param AbstractProblem problem: The formulation of the problem.
@@ -30,7 +31,7 @@ class BaseDataset(Dataset):
                 'Something is wrong, __slots__ must be defined in subclasses.')
         return object.__new__(cls)
 
-    def __init__(self, problem, device):
+    def __init__(self, problem=None, device=torch.device('cpu')):
         """
         Initialize the object based on __slots__
         :param AbstractProblem problem: The formulation of the problem.
         :param torch.device device: The device on which the
             dataset will be loaded.
         """
         super().__init__()
-
-        self.condition_names = {}
-        collector = problem.collector
+        self.empty = True
+        self.problem = problem
+        self.device = device
+        self.condition_indices = None
         for slot in self.__slots__:
             setattr(self, slot, [])
-        num_el_per_condition = []
-        idx = 0
-        for name, data in collector.data_collections.items():
+        self.num_el_per_condition = []
+        self.conditions_idx = []
+        if self.problem is not None:
+            self._init_from_problem(self.problem.collector.data_collections)
+        self.initialized = False
+
+    def _init_from_problem(self, collector_dict):
+        """
+        Initialize the dataset lists from the data collections stored in
+        the problem collector.
+        """
+        for name, data in collector_dict.items():
             keys = list(data.keys())
-            current_cond_num_el = None
-            if sorted(self.__slots__) == sorted(keys):
-                for slot in self.__slots__:
-                    slot_data = data[slot]
-                    if isinstance(slot_data, (LabelTensor, torch.Tensor,
-                                              Graph)):
-                        if current_cond_num_el is None:
-                            current_cond_num_el = len(slot_data)
-                        elif current_cond_num_el != len(slot_data):
-                            raise ValueError('Different number of conditions')
-                    current_list = getattr(self, slot)
-                    current_list += [data[slot]] if not (
-                        isinstance(data[slot], list)) else data[slot]
-                num_el_per_condition.append(current_cond_num_el)
-                self.condition_names[idx] = name
-                idx += 1
-        if num_el_per_condition:
+            if set(self.__slots__) == set(keys):
+                self._populate_init_list(data)
+                idx = [key for key, val in
+                       self.problem.collector.conditions_name.items() if
+                       val == name]
+                self.conditions_idx.append(idx)
+        self.initialize()
+
+    def add_points(self, data_dict, condition_idx, batching_dim=0):
+        """
+        Fill the internal lists with new data points
+        :param data_dict: dictionary containing data points
+        :param condition_idx: index of the condition to which the data points
+            belong
+        :param batching_dim: dimension of the batching
+        :raises ValueError: if the dataset has already been initialized
+        """
+        if not self.initialized:
+            self._populate_init_list(data_dict, batching_dim)
+            self.conditions_idx.append(condition_idx)
+            self.empty = False
+        else:
+            raise ValueError('Dataset already initialized')
+
+    def _populate_init_list(self, data_dict, batching_dim=0):
+        current_cond_num_el = None
+        for slot in data_dict.keys():
+            slot_data = data_dict[slot]
+            if batching_dim != 0:
+                if isinstance(slot_data, (LabelTensor, torch.Tensor)):
+                    dims = len(slot_data.size())
+                    slot_data = slot_data.permute(
+                        [batching_dim] + [dim for dim in range(dims) if
+                                          dim != batching_dim])
+            if current_cond_num_el is None:
+                current_cond_num_el = len(slot_data)
+            elif current_cond_num_el != len(slot_data):
+                raise ValueError('Different dimension in same condition')
+            current_list = getattr(self, slot)
+            current_list += [slot_data] if not (
+                isinstance(slot_data, list)) else slot_data
+        self.num_el_per_condition.append(current_cond_num_el)
+
+    def initialize(self):
+        """
+        Initialize the dataset tensors/LabelTensors/lists from the lists
+        already filled
+        """
+        logging.debug(f'Initialize dataset {self.__class__.__name__}')
+
+        if self.num_el_per_condition:
             self.condition_indices = torch.cat(
                 [
-                    torch.tensor([i] * num_el_per_condition[i],
+                    torch.tensor([i] * self.num_el_per_condition[i],
                                  dtype=torch.uint8)
-                    for i in range(len(num_el_per_condition))
+                    for i in range(len(self.num_el_per_condition))
                 ],
-                dim=0,
+                dim=0
             )
             for slot in self.__slots__:
                 current_attribute = getattr(self, slot)
                 if all(isinstance(a, LabelTensor) for a in current_attribute):
                     setattr(self, slot, LabelTensor.vstack(current_attribute))
-        else:
-            self.condition_indices = torch.tensor([],
-                                                  dtype=torch.uint8)
-            for slot in self.__slots__:
-                setattr(self, slot, torch.tensor([]))
-        self.device = device
+        self.initialized = True
 
     def __len__(self):
+        """
+        :return: Number of elements in the dataset
+        """
         return len(getattr(self, self.__slots__[0]))
 
-    def __getattribute__(self, item):
-        attribute = super().__getattribute__(item)
-        if isinstance(attribute,
-                      LabelTensor) and attribute.dtype == torch.float32:
-            attribute = attribute.to(device=self.device).requires_grad_()
-        return attribute
-
     def __getitem__(self, idx):
-        if isinstance(idx, str):
-            return getattr(self, idx).to(self.device)
-        if isinstance(idx, slice):
-            to_return_list = []
-            for i in self.__slots__:
-                to_return_list.append(getattr(self, i)[idx].to(self.device))
-            return to_return_list
+        """
+        :param idx: index (or indices) of the items to retrieve
+        :return: list with the selected items, one entry per slot
+        """
+        if not isinstance(idx, (tuple, list, slice, int)):
+            raise IndexError("Invalid index")
+        tensors = []
+        for attribute in self.__slots__:
+            tensor = getattr(self, attribute)
+            if isinstance(tensor, (LabelTensor, torch.Tensor)):
+                tensors.append(tensor.__getitem__(idx))
+            elif isinstance(tensor, list):
+                if isinstance(idx, (list, tuple)):
+                    tensor = [tensor[i] for i in idx]
+                tensors.append(tensor)
+        return tensors
 
-        if isinstance(idx, (tuple, list)):
-            if (len(idx) == 2 and isinstance(idx[0], str)
-                    and isinstance(idx[1], (list, slice))):
-                tensor = getattr(self, idx[0])
-                return tensor[[idx[1]]].to(self.device)
-            if all(isinstance(x, int) for x in idx):
-                to_return_list = []
-                for i in self.__slots__:
-                    to_return_list.append(
-                        getattr(self, i)[[idx]].to(self.device))
-                return to_return_list
-
-        raise ValueError(f'Invalid index {idx}')
+    def apply_shuffle(self, indices):
+        for slot in self.__slots__:
+            if slot != 'equation':
+                attribute = getattr(self, slot)
+                if isinstance(attribute, (LabelTensor, torch.Tensor)):
+                    setattr(self, slot, attribute[[indices]])
+                elif isinstance(attribute, list):
+                    setattr(self, slot, [attribute[i] for i in indices])
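The new dataset lifecycle above is two-phase: construct empty, fill per condition with `add_points`, then freeze with `initialize`. A minimal sketch of the intended flow (the `SupervisedDataset` slot names `input_points`/`output_points` are assumptions for illustration, not taken from this diff):

```python
import torch

from pina import LabelTensor
from pina.data.supervised_dataset import SupervisedDataset

# Datasets now start empty and are filled condition by condition.
dataset = SupervisedDataset(device=torch.device('cpu'))

# The dict keys must match the dataset's __slots__ exactly
# (hypothetical slot names used here).
dataset.add_points(
    {'input_points': LabelTensor(torch.rand(10, 2), ['x', 'y']),
     'output_points': LabelTensor(torch.rand(10, 1), ['u'])},
    condition_idx=0)

# initialize() stacks the per-condition lists into single tensors
# and builds the condition_indices vector.
dataset.initialize()
assert len(dataset) == 10
```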
diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index 25c7e54..98460ae 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -4,7 +4,8 @@ This module provide basic data management functionalities
 
 import math
 import torch
-from lightning import LightningDataModule
+import logging
+from pytorch_lightning import LightningDataModule
 from .sample_dataset import SamplePointDataset
 from .supervised_dataset import SupervisedDataset
 from .unsupervised_dataset import UnsupervisedDataset
@@ -22,8 +23,9 @@ class PinaDataModule(LightningDataModule):
                  problem,
                  device,
                  train_size=.7,
-                 test_size=.2,
-                 eval_size=.1,
+                 test_size=.1,
+                 val_size=.2,
+                 predict_size=0.,
                  batch_size=None,
                  shuffle=True,
                  datasets=None):
@@ -37,37 +39,64 @@ class PinaDataModule(LightningDataModule):
         :param batch_size: batch size used for training
         :param datasets: list of datasets objects
         """
+        logging.debug('Start initialization of Pina DataModule')
         super().__init__()
-        dataset_classes = [SupervisedDataset, UnsupervisedDataset,
-                           SamplePointDataset]
+        self.problem = problem
+        self.device = device
+        self.dataset_classes = [SupervisedDataset, UnsupervisedDataset,
+                                SamplePointDataset]
         if datasets is None:
-            self.datasets = [DatasetClass(problem, device) for DatasetClass in
-                             dataset_classes]
+            self.datasets = None
         else:
             self.datasets = datasets
 
         self.split_length = []
         self.split_names = []
+        self.loader_functions = {}
+        self.batch_size = batch_size
+        self.condition_names = problem.collector.conditions_name
+
         if train_size > 0:
             self.split_names.append('train')
             self.split_length.append(train_size)
+            self.loader_functions['train_dataloader'] = lambda: PinaDataLoader(
+                self.splits['train'], self.batch_size, self.condition_names)
         if test_size > 0:
             self.split_length.append(test_size)
             self.split_names.append('test')
-        if eval_size > 0:
-            self.split_length.append(eval_size)
-            self.split_names.append('eval')
-
-        self.batch_size = batch_size
-        self.condition_names = None
+            self.loader_functions['test_dataloader'] = lambda: PinaDataLoader(
+                self.splits['test'], self.batch_size, self.condition_names)
+        if val_size > 0:
+            self.split_length.append(val_size)
+            self.split_names.append('val')
+            self.loader_functions['val_dataloader'] = lambda: PinaDataLoader(
+                self.splits['val'], self.batch_size, self.condition_names)
+        if predict_size > 0:
+            self.split_length.append(predict_size)
+            self.split_names.append('predict')
+            self.loader_functions[
+                'predict_dataloader'] = lambda: PinaDataLoader(
+                self.splits['predict'], self.batch_size,
+                self.condition_names)
 
         self.splits = {k: {} for k in self.split_names}
         self.shuffle = shuffle
 
+        for k, v in self.loader_functions.items():
+            setattr(self, k, v)
+
+    def prepare_data(self):
+        if self.datasets is None:
+            self._create_datasets()
+
     def setup(self, stage=None):
         """
         Perform the splitting of the dataset
         """
-        self.extract_conditions()
+        logging.debug('Start setup of Pina DataModule obj')
+        if self.datasets is None:
+            self._create_datasets()
         if stage == 'fit' or stage is None:
             for dataset in self.datasets:
                 if len(dataset) > 0:
@@ -82,53 +111,6 @@ class PinaDataModule(LightningDataModule):
             else:
                 raise ValueError("stage must be either 'fit' or 'test'")
 
-    def extract_conditions(self):
-        """
-        Extract conditions from dataset and update condition indices
-        """
-        # Extract number of conditions
-        n_conditions = 0
-        for dataset in self.datasets:
-            if n_conditions != 0:
-                dataset.condition_names = {
-                    key + n_conditions: value
-                    for key, value in dataset.condition_names.items()
-                }
-            n_conditions += len(dataset.condition_names)
-
-        self.condition_names = {
-            key: value
-            for dataset in self.datasets
-            for key, value in dataset.condition_names.items()
-        }
-
-    def train_dataloader(self):
-        """
-        Return the training dataloader for the dataset
-        :return: data loader
-        :rtype: PinaDataLoader
-        """
-        return PinaDataLoader(self.splits['train'], self.batch_size,
-                              self.condition_names)
-
-    def test_dataloader(self):
-        """
-        Return the testing dataloader for the dataset
-        :return: data loader
-        :rtype: PinaDataLoader
-        """
-        return PinaDataLoader(self.splits['test'], self.batch_size,
-                              self.condition_names)
-
-    def eval_dataloader(self):
-        """
-        Return the evaluation dataloader for the dataset
-        :return: data loader
-        :rtype: PinaDataLoader
-        """
-        return PinaDataLoader(self.splits['eval'], self.batch_size,
-                              self.condition_names)
-
     @staticmethod
     def dataset_split(dataset, lengths, seed=None, shuffle=True):
         """
@@ -141,30 +123,28 @@ class PinaDataModule(LightningDataModule):
         :param seed: random seed
         :param shuffle: shuffle dataset
         :return: dataset split
         :rtype: PinaSubset
         """
         if sum(lengths) - 1 < 1e-3:
+            len_dataset = len(dataset)
             lengths = [
-                int(math.floor(len(dataset) * length)) for length in lengths
+                int(math.floor(len_dataset * length)) for length in lengths
             ]
-            remainder = len(dataset) - sum(lengths)
+            remainder = len_dataset - sum(lengths)
             for i in range(remainder):
                 lengths[i % len(lengths)] += 1
         elif sum(lengths) - 1 >= 1e-3:
             raise ValueError(f"Sum of lengths is {sum(lengths)} less than 1")
 
-        if sum(lengths) != len(dataset):
-            raise ValueError("Sum of lengths is not equal to dataset length")
-
         if shuffle:
             if seed is not None:
                 generator = torch.Generator()
                 generator.manual_seed(seed)
                 indices = torch.randperm(sum(lengths),
-                                         generator=generator).tolist()
+                                         generator=generator)
             else:
-                indices = torch.arange(sum(lengths)).tolist()
-        else:
-            indices = torch.arange(0, sum(lengths), 1,
-                                   dtype=torch.uint8).tolist()
+                indices = torch.randperm(sum(lengths))
+            dataset.apply_shuffle(indices)
+
+        indices = torch.arange(0, sum(lengths), 1,
+                               dtype=torch.uint8).tolist()
         offsets = [
             sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
         ]
@@ -172,3 +152,29 @@ class PinaDataModule(LightningDataModule):
             PinaSubset(dataset, indices[offset:offset + length])
             for offset, length in zip(offsets, lengths)
         ]
+
+    def _create_datasets(self):
+        """
+        Create the dataset objects and fill them with data
+        """
+        logging.debug('Dataset creation in PinaDataModule obj')
+        collector = self.problem.collector
+        batching_dim = self.problem.batching_dimension
+        datasets_slots = [i.__slots__ for i in self.dataset_classes]
+        self.datasets = [dataset(device=self.device) for dataset in
+                         self.dataset_classes]
+        logging.debug('Filling datasets in PinaDataModule obj')
+        for name, data in collector.data_collections.items():
+            keys = list(data.keys())
+            idx = [key for key, val in collector.conditions_name.items() if
+                   val == name]
+            for i, slot in enumerate(datasets_slots):
+                if slot == keys:
+                    self.datasets[i].add_points(data, idx[0], batching_dim)
+                    break
+        datasets = []
+        for dataset in self.datasets:
+            if not dataset.empty:
+                dataset.initialize()
+                datasets.append(dataset)
+        self.datasets = datasets
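With the reworked `PinaDataModule`, dataloaders are registered only for splits whose size is positive, so e.g. a `predict_dataloader` simply does not exist when `predict_size=0.`. A usage sketch, assuming `problem` is any `AbstractProblem` with a populated collector:

```python
from pina.data import PinaDataModule

dm = PinaDataModule(problem=problem,   # hypothetical problem instance
                    device='cpu',
                    train_size=0.7,
                    test_size=0.1,
                    val_size=0.2,
                    batch_size=32)
dm.setup()                             # creates, fills and splits datasets
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()       # registered because val_size > 0
```

Registering the loaders as lambdas at construction time (instead of the fixed `train/test/eval_dataloader` methods that were removed) keeps PyTorch Lightning's hook discovery working while never exposing a loader for an empty split.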
""" - def __init__(self, dataset_dict, idx_dict): - + def __init__(self, dataset_dict, idx_dict, require_grad=True): + self.attributes = [] for k, v in dataset_dict.items(): setattr(self, k, v) + self.attributes.append(k) for k, v in idx_dict.items(): setattr(self, k + '_idx', v) + self.require_grad = require_grad def __len__(self): """ @@ -31,9 +33,18 @@ class Batch: length += len(getattr(self, dataset)) return length + def __getattribute__(self, item): + if item in super().__getattribute__('attributes'): + dataset = super().__getattribute__(item) + index = super().__getattribute__(item + '_idx') + return PinaSubset( + dataset.dataset, + dataset.indices[index]) + else: + return super().__getattribute__(item) + def __getattr__(self, item): - if not item in dir(self): - raise AttributeError(f'Batch instance has no attribute {item}') - return PinaSubset( - getattr(self, item).dataset, - getattr(self, item).indices[self.coordinates_dict[item]]) + if item == 'data' and len(self.attributes) == 1: + item = self.attributes[0] + return super().__getattribute__(item) + raise AttributeError(f"'Batch' object has no attribute '{item}'") \ No newline at end of file diff --git a/pina/data/pina_subset.py b/pina/data/pina_subset.py index f1347b6..275541e 100644 --- a/pina/data/pina_subset.py +++ b/pina/data/pina_subset.py @@ -2,21 +2,22 @@ Module for PinaSubset class """ from pina import LabelTensor -from torch import Tensor +from torch import Tensor, float32 class PinaSubset: """ TODO """ - __slots__ = ['dataset', 'indices'] + __slots__ = ['dataset', 'indices', 'require_grad'] - def __init__(self, dataset, indices): + def __init__(self, dataset, indices, require_grad=True): """ TODO """ self.dataset = dataset self.indices = indices + self.require_grad = require_grad def __len__(self): """ @@ -27,7 +28,9 @@ class PinaSubset: def __getattr__(self, name): tensor = self.dataset.__getattribute__(name) if isinstance(tensor, (LabelTensor, Tensor)): - return tensor[self.indices] + tensor = tensor[[self.indices]].to(self.dataset.device) + return tensor.requires_grad_( + self.require_grad) if tensor.dtype == float32 else tensor if isinstance(tensor, list): return [tensor[i] for i in self.indices] - raise AttributeError("No attribute named {}".format(name)) + raise AttributeError(f"No attribute named {name}") diff --git a/pina/data/sample_dataset.py b/pina/data/sample_dataset.py index 99811ca..5c47a14 100644 --- a/pina/data/sample_dataset.py +++ b/pina/data/sample_dataset.py @@ -1,8 +1,9 @@ """ Sample dataset module """ +from copy import deepcopy from .base_dataset import BaseDataset -from ..condition.input_equation_condition import InputPointsEquationCondition +from ..condition import InputPointsEquationCondition class SamplePointDataset(BaseDataset): @@ -12,3 +13,21 @@ class SamplePointDataset(BaseDataset): """ data_type = 'physics' __slots__ = InputPointsEquationCondition.__slots__ + + def add_points(self, data_dict, condition_idx, batching_dim=0): + data_dict = deepcopy(data_dict) + data_dict.pop('equation') + super().add_points(data_dict, condition_idx) + + def _init_from_problem(self, collector_dict, batching_dim=0): + for name, data in collector_dict.items(): + keys = list(data.keys()) + if set(self.__slots__) == set(keys): + data = deepcopy(data) + data.pop('equation') + self._populate_init_list(data) + idx = [key for key, val in + self.problem.collector.conditions_name.items() if + val == name] + self.conditions_idx.append(idx) + self.initialize() diff --git a/pina/label_tensor.py b/pina/label_tensor.py 
diff --git a/pina/label_tensor.py b/pina/label_tensor.py
index 87def2f..a28a3ea 100644
--- a/pina/label_tensor.py
+++ b/pina/label_tensor.py
@@ -1,5 +1,5 @@
 """ Module for LabelTensor """
-from copy import deepcopy, copy
+from copy import copy
 
 import torch
 from torch import Tensor
@@ -8,21 +8,29 @@ def issubset(a, b):
     """
     Check if a is a subset of b.
     """
-    return set(a).issubset(set(b))
+    if isinstance(a, list) and isinstance(b, list):
+        return set(a).issubset(set(b))
+    elif isinstance(a, range) and isinstance(b, range):
+        return a.start <= b.start and a.stop >= b.stop
+    else:
+        return False
 
 
 class LabelTensor(torch.Tensor):
     """Torch tensor with a label for any column."""
 
     @staticmethod
-    def __new__(cls, x, labels, *args, **kwargs):
-        return super().__new__(cls, x, *args, **kwargs)
+    def __new__(cls, x, labels, full=True, *args, **kwargs):
+        if isinstance(x, LabelTensor):
+            return x
+        else:
+            return super().__new__(cls, x, *args, **kwargs)
 
     @property
     def tensor(self):
         return self.as_subclass(Tensor)
 
-    def __init__(self, x, labels):
+    def __init__(self, x, labels, full=False):
         """
         Construct a `LabelTensor` by passing a dict of the labels
 
@@ -34,8 +42,17 @@ class LabelTensor(torch.Tensor):
         """
         self.dim_names = None
+        self.full = full
         self.labels = labels
 
+    @classmethod
+    def __internal_init__(cls, x, labels, dim_names, full=False, *args,
+                          **kwargs):
+        lt = cls.__new__(cls, x, labels, full, *args, **kwargs)
+        lt._labels = labels
+        lt.full = full
+        lt.dim_names = dim_names
+        return lt
+
     @property
     def labels(self):
         """Property decorator for labels
@@ -43,12 +60,29 @@ class LabelTensor(torch.Tensor):
         :return: labels of self
         :rtype: list
         """
-        return self._labels[self.tensor.ndim - 1]['dof']
+        if self.ndim - 1 in self._labels.keys():
+            return self._labels[self.ndim - 1]['dof']
 
     @property
     def full_labels(self):
         """Property decorator for labels
 
+        :return: labels of self
+        :rtype: list
+        """
+        to_return_dict = {}
+        shape_tensor = self.shape
+        for i in range(len(shape_tensor)):
+            if i in self._labels.keys():
+                to_return_dict[i] = self._labels[i]
+            else:
+                to_return_dict[i] = {'dof': range(shape_tensor[i]), 'name': i}
+        return to_return_dict
+
+    @property
+    def stored_labels(self):
+        """Property decorator for labels
+
         :return: labels of self
         :rtype: list
         """
         return self._labels
@@ -62,26 +96,77 @@ class LabelTensor(torch.Tensor):
         :param labels: Labels to assign to the class variable _labels.
         :type: labels: str | list(str) | dict
         """
-        if hasattr(self, 'labels') is False:
-            self.init_labels()
+        if not hasattr(self, '_labels'):
+            self._labels = {}
         if isinstance(labels, dict):
-            self.update_labels_from_dict(labels)
+            self._init_labels_from_dict(labels)
         elif isinstance(labels, list):
-            self.update_labels_from_list(labels)
+            self._init_labels_from_list(labels)
         elif isinstance(labels, str):
             labels = [labels]
-            self.update_labels_from_list(labels)
+            self._init_labels_from_list(labels)
         else:
             raise ValueError("labels must be list, dict or string.")
         self.set_names()
 
+    def _init_labels_from_dict(self, labels):
+        """
+        Update the internal label representation according to the values
+        passed as input.
+
+        :param labels: The label(s) to update.
+        :type labels: dict
+        :raises ValueError: dof list contains duplicates or the number of
+            dof does not match the tensor shape
+        """
+        tensor_shape = self.shape
+
+        if hasattr(self, 'full') and self.full:
+            labels = {i: labels[i] if i in labels else {'name': i}
+                      for i in range(len(tensor_shape))}
+        for k, v in labels.items():
+            # Init labels from str
+            if isinstance(v, str):
+                v = {'name': v, 'dof': range(tensor_shape[k])}
+            # Init labels from dict with only the name key
+            elif isinstance(v, dict) and list(v.keys()) == ['name']:
+                v['dof'] = range(tensor_shape[k])
+            # Init labels from dict with both name and dof keys
+            elif isinstance(v, dict) and sorted(list(v.keys())) == ['dof',
+                                                                    'name']:
+                dof_list = v['dof']
+                dof_len = len(dof_list)
+                if dof_len != len(set(dof_list)):
+                    raise ValueError("dof must be unique")
+                if dof_len != tensor_shape[k]:
+                    raise ValueError(
+                        'Number of dof does not match tensor shape')
+            else:
+                raise ValueError('Illegal labels initialization')
+            # Perform update
+            self._labels[k] = v
+
+    def _init_labels_from_list(self, labels):
+        """
+        Given a list of dof, this method updates the internal label
+        representation.
+
+        :param labels: The label(s) to update.
+        :type labels: list
+        """
+        # Create a dict with labels for the last dimension
+        last_dim_labels = {
+            self.ndim - 1: {'dof': labels, 'name': self.ndim - 1}}
+        self._init_labels_from_dict(last_dim_labels)
+
     def set_names(self):
-        labels = self.full_labels
+        labels = self.stored_labels
         self.dim_names = {}
-        for dim in range(self.tensor.ndim):
+        for dim in labels.keys():
             self.dim_names[labels[dim]['name']] = dim
 
@@ -91,78 +176,68 @@ class LabelTensor(torch.Tensor):
-    def extract(self, label_to_extract):
+    def extract(self, labels_to_extract):
         """
         Extract the subset of the original tensor by returning all the columns
         corresponding to the passed ``label_to_extract``.
 
         :param label_to_extract: The label(s) to extract.
         :type label_to_extract: str | list(str) | tuple(str)
         :raises TypeError: Labels are not ``str``.
         :raises ValueError: Label to extract is not in the labels ``list``.
         """
-        if isinstance(label_to_extract, (str, int)):
-            label_to_extract = [label_to_extract]
-        if isinstance(label_to_extract, (tuple, list)):
-            return self._extract_from_list(label_to_extract)
-        if isinstance(label_to_extract, dict):
-            return self._extract_from_dict(label_to_extract)
-        raise ValueError('labels_to_extract must be str or list or dict')
+        # Convert str/int to list
+        if isinstance(labels_to_extract, (str, int)):
+            labels_to_extract = [labels_to_extract]
 
-    def _extract_from_list(self, labels_to_extract):
-        # Store locally all necessary obj/variables
-        ndim = self.tensor.ndim
-        labels = self.full_labels
-        tensor = self.tensor
-        last_dim_label = self.labels
+        # Store useful variables
+        labels = self.stored_labels
+        stored_keys = labels.keys()
+        dim_names = self.dim_names
+        ndim = len(super().shape)
 
-        # Verify if all the labels in labels_to_extract are in last dimension
-        if set(labels_to_extract).issubset(last_dim_label) is False:
-            raise ValueError(
-                'Cannot extract a dof which is not in the original LabelTensor')
-
-        # Extract index to extract
-        idx_to_extract = [last_dim_label.index(i) for i in labels_to_extract]
-
-        # Perform extraction
-        new_tensor = tensor[..., idx_to_extract]
-
-        # Manage labels
-        new_labels = copy(labels)
-
-        last_dim_new_label = {ndim - 1: {
-            'dof': list(labels_to_extract),
-            'name': labels[ndim - 1]['name']
-        }}
-        new_labels.update(last_dim_new_label)
-        return LabelTensor(new_tensor, new_labels)
-
-    def _extract_from_dict(self, labels_to_extract):
-        labels = self.full_labels
-        tensor = self.tensor
-        ndim = tensor.ndim
-        new_labels = deepcopy(labels)
-        new_tensor = tensor
-        for k, _ in labels_to_extract.items():
-            idx_dim = self.dim_names[k]
-            dim_labels = labels[idx_dim]['dof']
-            if isinstance(labels_to_extract[k], (int, str)):
-                labels_to_extract[k] = [labels_to_extract[k]]
-            if set(labels_to_extract[k]).issubset(dim_labels) is False:
+        # Convert tuple/list to dict
+        if isinstance(labels_to_extract, (tuple, list)):
+            if ndim - 1 not in stored_keys:
                 raise ValueError(
-                    'Cannot extract a dof which is not in the original '
-                    'LabelTensor')
-            idx_to_extract = [dim_labels.index(i) for i in labels_to_extract[k]]
-            indexer = [slice(None)] * idx_dim + [idx_to_extract] + [
-                slice(None)] * (ndim - idx_dim - 1)
-            new_tensor = new_tensor[indexer]
-            dim_new_label = {idx_dim: {
-                'dof': labels_to_extract[k],
-                'name': labels[idx_dim]['name']
-            }}
-            new_labels.update(dim_new_label)
-        return LabelTensor(new_tensor, new_labels)
+                    "LabelTensor does not have labels in last dimension")
+            name = labels[max(stored_keys)]['name']
+            labels_to_extract = {name: list(labels_to_extract)}
+
+        # If labels_to_extract is not a dict, raise an error
+        if not isinstance(labels_to_extract, dict):
+            raise ValueError('labels_to_extract must be str or list or dict')
+
+        # Make a copy of the labels (to avoid consistency issues)
+        updated_labels = {k: copy(v) for k, v in labels.items()}
+
+        # Initialize the list used to perform the extraction
+        extractor = [slice(None) for _ in range(ndim)]
+
+        # Loop over the labels_to_extract dict
+        for k, v in labels_to_extract.items():
+
+            # If the label is not found, raise an error
+            idx_dim = dim_names.get(k)
+            if idx_dim is None:
+                raise ValueError(
+                    'Cannot extract a label which is not in the original '
+                    'labels')
+
+            dim_labels = labels[idx_dim]['dof']
+            v = [v] if isinstance(v, (int, str)) else v
+
+            if not isinstance(v, range):
+                extractor[idx_dim] = [dim_labels.index(i) for i in v] \
+                    if len(v) > 1 else slice(dim_labels.index(v[0]),
+                                             dim_labels.index(v[0]) + 1)
+            else:
+                extractor[idx_dim] = slice(v.start, v.stop)
+
+            updated_labels.update({idx_dim: {'dof': v, 'name': k}})
+
+        tensor = self.tensor
+        tensor = tensor[extractor]
+        return LabelTensor.__internal_init__(tensor, updated_labels, dim_names)
 
     def __str__(self):
         """
         returns a string with the representation of the class
         """
-
         s = ''
         for key, value in self._labels.items():
             s += f"{key}: {value}\n"
         s += '\n'
-        s += super().__str__()
+        s += self.tensor.__str__()
         return s
 
@@ -174,55 +249,44 @@ class LabelTensor(torch.Tensor):
     @staticmethod
     def cat(tensors, dim=0):
         """
         :param tensors: tensors to concatenate
         :type tensors: list(LabelTensor)
-        :param dim: dimensions on which you want to perform the operation (default 0)
+        :param dim: dimensions on which you want to perform the operation
+            (default 0)
         :type dim: int
         :rtype: LabelTensor
         :raises ValueError: either number dof or dimensions names differ
         """
         if len(tensors) == 0:
             return []
-        if len(tensors) == 1:
+        if len(tensors) == 1 or isinstance(tensors, LabelTensor):
             return tensors[0]
-        new_labels_cat_dim = LabelTensor._check_validity_before_cat(tensors,
-                                                                    dim)
-
         # Perform cat on tensors
         new_tensor = torch.cat(tensors, dim=dim)
 
         # Update labels
-        labels = tensors[0].full_labels
-        labels.pop(dim)
-        new_labels_cat_dim = new_labels_cat_dim if len(
-            set(new_labels_cat_dim)) == len(new_labels_cat_dim) \
-            else range(new_tensor.shape[dim])
-        labels[dim] = {'dof': new_labels_cat_dim,
-                       'name': tensors[1].full_labels[dim]['name']}
-        return LabelTensor(new_tensor, labels)
+        labels = LabelTensor.__create_labels_cat(tensors, dim)
+
+        return LabelTensor.__internal_init__(new_tensor, labels,
+                                             tensors[0].dim_names)
 
     @staticmethod
-    def _check_validity_before_cat(tensors, dim):
-        n_dims = tensors[0].ndim
-        new_labels_cat_dim = []
+    def __create_labels_cat(tensors, dim):
         # Check if names and dof of the labels are the same in all dimensions
         # except in dim
-        for i in range(n_dims):
-            name = tensors[0].full_labels[i]['name']
-            if i != dim:
-                dof = tensors[0].full_labels[i]['dof']
-                for tensor in tensors:
-                    dof_to_check = tensor.full_labels[i]['dof']
-                    name_to_check = tensor.full_labels[i]['name']
-                    if dof != dof_to_check or name != name_to_check:
-                        raise ValueError(
-                            'dimensions must have the same dof and name')
-            else:
-                for tensor in tensors:
-                    new_labels_cat_dim += tensor.full_labels[i]['dof']
-                    name_to_check = tensor.full_labels[i]['name']
-                    if name != name_to_check:
-                        raise ValueError(
-                            'Dimensions to concatenate must have the same name')
-        return new_labels_cat_dim
+        stored_labels = [tensor.stored_labels for tensor in tensors]
+
+        # check if:
+        #   - labels dicts have the same keys
+        #   - all labels are the same except for dimension dim
+        if not all(all(stored_labels[i][k] == stored_labels[0][k]
+                       for i in range(len(stored_labels)))
+                   for k in stored_labels[0].keys() if k != dim):
+            raise RuntimeError('tensors must have the same shape and dof')
+
+        labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
+        if dim in labels.keys():
+            last_dim_dof = [i for j in stored_labels for i in j[dim]['dof']]
+            labels[dim]['dof'] = last_dim_dof
+        return labels
 
     def requires_grad_(self, mode=True):
         lt = super().requires_grad_(mode)
 
@@ -251,52 +315,10 @@ class LabelTensor(torch.Tensor):
         :return: A copy of the tensor.
         :rtype: LabelTensor
         """
-
-        out = LabelTensor(super().clone(*args, **kwargs), self._labels)
+        labels = {k: copy(v) for k, v in self._labels.items()}
+        out = LabelTensor(super().clone(*args, **kwargs), labels)
         return out
 
-    def init_labels(self):
-        self._labels = {
-            idx_: {
-                'dof': range(self.tensor.shape[idx_]),
-                'name': idx_
-            } for idx_ in range(self.tensor.ndim)
-        }
-
-    def update_labels_from_dict(self, labels):
-        """
-        Update the internal label representation according to the values passed
-        as input.
-
-        :param labels: The label(s) to update.
-        :type labels: dict
-        :raises ValueError: dof list contain duplicates or number of dof does
-            not match with tensor shape
-        """
-        tensor_shape = self.tensor.shape
-        # Check dimensionality
-        for k, v in labels.items():
-            if len(v['dof']) != len(set(v['dof'])):
-                raise ValueError("dof must be unique")
-            if len(v['dof']) != tensor_shape[k]:
-                raise ValueError(
-                    'Number of dof does not match with tensor dimension')
-        # Perform update
-        self._labels.update(labels)
-
-    def update_labels_from_list(self, labels):
-        """
-        Given a list of dof, this method update the internal label
-        representation
-
-        :param labels: The label(s) to update.
-        :type labels: list
-        """
-        # Create a dict with labels
-        last_dim_labels = {
-            self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}}
-        self.update_labels_from_dict(last_dim_labels)
-
@@ -304,25 +326,30 @@ class LabelTensor(torch.Tensor):
     @staticmethod
     def summation(tensors):
         if len(tensors) == 0:
             raise ValueError('tensors list must not be empty')
         if len(tensors) == 1:
             return tensors[0]
-        # Collect all labels
-        labels = tensors[0].full_labels
+
         # Check labels of all the tensors in each dimension
-        for j in range(tensors[0].ndim):
-            for i in range(1, len(tensors)):
-                if labels[j] != tensors[i].full_labels[j]:
-                    labels.pop(j)
-                    break
-        # Sum tensors
+        if not all(tensor.shape == tensors[0].shape for tensor in tensors) or \
+                not all(tensor.full_labels[i] == tensors[0].full_labels[i] for
+                        tensor in tensors for i in range(tensors[0].ndim - 1)):
+            raise RuntimeError('Tensors must have the same shape and labels')
+
+        last_dim_labels = []
         data = torch.zeros(tensors[0].tensor.shape)
         for tensor in tensors:
             data += tensor.tensor
-        new_tensor = LabelTensor(data, labels)
-        return new_tensor
+            last_dim_labels.append(tensor.labels)
+
+        last_dim_labels = ['+'.join(items) for items in zip(*last_dim_labels)]
+        labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
+        labels.update({tensors[0].ndim - 1: {
+            'dof': last_dim_labels,
+            'name': tensors[0].full_labels[tensors[0].ndim - 1]['name']}})
+        return LabelTensor(data, labels)
 
     def append(self, tensor, mode='std'):
         if mode == 'std':
             # Call cat on last dimension
             new_label_tensor = LabelTensor.cat([self, tensor],
-                                               dim=self.tensor.ndim - 1)
+                                               dim=self.ndim - 1)
         elif mode == 'cross':
             # Crete tensor and call cat on last dimension
             tensor1 = self
             tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
                                   labels=tensor2.labels)
             new_label_tensor = LabelTensor.cat([tensor1, tensor2],
-                                               dim=self.tensor.ndim - 1)
+                                               dim=self.ndim - 1)
         else:
             raise ValueError('mode must be either "std" or "cross"')
         return new_label_tensor
 
@@ -357,97 +384,76 @@ class LabelTensor(torch.Tensor):
         :param index:
         :return:
         """
-
         if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(
                 isinstance(a, str) for a in index)):
             return self.extract(index)
+
         selected_lt = super().__getitem__(index)
 
         if isinstance(index, (int, slice)):
-            return self._getitem_int_slice(index, selected_lt)
+            index = [index]
 
-        if len(index) == self.tensor.ndim:
-            return self._getitem_full_dim_indexing(index, selected_lt)
+        if index[0] == Ellipsis:
+            index = [slice(None)] * (self.ndim - 1) + [index[1]]
 
-        if isinstance(index, torch.Tensor) or (
-                isinstance(index, (tuple, list)) and all(
-                isinstance(x, int) for x in index)):
-            return self._getitem_permutation(index, selected_lt)
-        raise ValueError('Not recognized index type')
-
-    def _getitem_int_slice(self, index, selected_lt):
-        """
-        :param index:
-        :param selected_lt:
-        :return:
-        """
-        if selected_lt.ndim == 1:
-            selected_lt = selected_lt.reshape(1, -1)
         if hasattr(self, "labels"):
-            new_labels = deepcopy(self.full_labels)
-            to_update_dof = new_labels[0]['dof'][index]
-            to_update_dof = to_update_dof if isinstance(to_update_dof, (
-                tuple, list, range)) else [to_update_dof]
-            new_labels.update(
-                {0: {'dof': to_update_dof, 'name': new_labels[0]['name']}}
-            )
-            selected_lt.labels = new_labels
-        return selected_lt
-
-    def _getitem_full_dim_indexing(self, index, selected_lt):
-        new_labels = {}
-        old_labels = self.full_labels
-        if selected_lt.ndim == 1:
-            selected_lt = selected_lt.reshape(-1, 1)
-            new_labels = deepcopy(old_labels)
-            new_labels[1].update({'dof': old_labels[1]['dof'][index[1]],
-                                  'name': old_labels[1]['name']})
-        idx = 0
-        for j in range(selected_lt.ndim):
-            if not isinstance(index[j], int):
-                if hasattr(self, "labels"):
-                    new_labels.update(
-                        self._update_label_for_dim(old_labels, index[j], idx))
-                idx += 1
-        selected_lt.labels = new_labels
-        return selected_lt
-
-    def _getitem_permutation(self, index, selected_lt):
-        new_labels = deepcopy(self.full_labels)
-        new_labels.update(self._update_label_for_dim(self.full_labels, index,
-                                                     0))
-        selected_lt.labels = self.labels
+            labels = {k: copy(v) for k, v in self.stored_labels.items()}
+            for j, idx in enumerate(index):
+                if isinstance(idx, int):
+                    selected_lt = selected_lt.unsqueeze(j)
+                if j in labels.keys() and idx != slice(None):
+                    self._update_single_label(labels, labels, idx, j)
+            selected_lt = LabelTensor.__internal_init__(selected_lt, labels,
+                                                        self.dim_names)
         return selected_lt
 
     @staticmethod
-    def _update_label_for_dim(old_labels, index, dim):
+    def _update_single_label(old_labels, to_update_labels, index, dim):
         """
         TODO
-        :param old_labels:
-        :param index:
-        :param dim:
+        :param old_labels: labels from which to retrieve data
+        :param to_update_labels: labels to update
+        :param index: index of dof to retain
+        :param dim: label index
         :return:
         """
+        old_dof = old_labels[dim]['dof']
+        if not isinstance(index, (int, slice)) and len(index) == len(
+                old_dof) and isinstance(old_dof, range):
+            return
         if isinstance(index, torch.Tensor):
-            index = index.nonzero()
+            index = index.nonzero(as_tuple=True)[
+                0] if index.dtype == torch.bool else index.tolist()
         if isinstance(index, list):
-            return {dim: {'dof': [old_labels[dim]['dof'][i] for i in index],
-                          'name': old_labels[dim]['name']}}
+            to_update_labels.update({dim: {
+                'dof': [old_dof[i] for i in index],
+                'name': old_labels[dim]['name']}})
         else:
-            return {dim: {'dof': old_labels[dim]['dof'][index],
-                          'name': old_labels[dim]['name']}}
+            to_update_labels.update({dim: {'dof': old_dof[index],
+                                           'name': old_labels[dim]['name']}})
 
     def sort_labels(self, dim=None):
-        def argsort(lst):
+        def arg_sort(lst):
             return sorted(range(len(lst)), key=lambda x: lst[x])
 
         if dim is None:
-            dim = self.tensor.ndim - 1
-        labels = self.full_labels[dim]['dof']
-        sorted_index = argsort(labels)
-        indexer = [slice(None)] * self.tensor.ndim
+            dim = self.ndim - 1
+        labels = self.stored_labels[dim]['dof']
+        sorted_index = arg_sort(labels)
+        indexer = [slice(None)] * self.ndim
         indexer[dim] = sorted_index
-        new_labels = deepcopy(self.full_labels)
-        new_labels[dim] = {'dof': sorted(labels),
-                           'name': new_labels[dim]['name']}
-        return LabelTensor(self.tensor[indexer], new_labels)
+        return self.__getitem__(indexer)
+
+    def __deepcopy__(self, memo):
+        from copy import deepcopy
+        cls = self.__class__
+        result = cls(deepcopy(self.tensor), deepcopy(self.stored_labels))
+        return result
+
+    def permute(self, *dims):
+        tensor = super().permute(*dims)
+        stored_labels = self.stored_labels
+        keys_list = list(dims[0]) if len(dims) == 1 else list(dims)
+        labels = {keys_list.index(k): copy(stored_labels[k]) for k in
+                  stored_labels.keys()}
+        return LabelTensor.__internal_init__(tensor, labels, self.dim_names)
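The reworked `LabelTensor` stores labels sparsely per dimension (`stored_labels`), reconstructs defaults on demand (`full_labels`), and lets `extract` address any labelled dimension by name. A short sketch (the dimension name `'space'` is an arbitrary example):

```python
import torch

from pina import LabelTensor

# Label only the last dimension; dimension 0 keeps a default range dof.
lt = LabelTensor(torch.rand(20, 3),
                 {1: {'name': 'space', 'dof': ['x', 'y', 'z']}})
assert lt.full_labels[0]['dof'] == range(20)

# Extraction by dof list acts on the last labelled dimension...
xy = lt.extract(['x', 'y'])
assert xy.shape == (20, 2)

# ...while a dict keyed by dimension name can address any dimension.
xz = lt.extract({'space': ['x', 'z']})
assert xz.labels == ['x', 'z']
```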
diff --git a/pina/operators.py b/pina/operators.py
index 083837c..48af1da 100644
--- a/pina/operators.py
+++ b/pina/operators.py
@@ -85,7 +85,8 @@ def grad(output_, input_, components=None, d=None):
             raise RuntimeError
 
         gradients = grad_scalar_output(output_, input_, d)
 
-    elif output_.shape[output_.ndim - 1] >= 2:  # vector output ##############################
+    elif output_.shape[
+            output_.ndim - 1] >= 2:  # vector output
         tensor_to_cat = []
         for i, c in enumerate(components):
             c_output = output_.extract([c])
@@ -143,7 +144,6 @@ def div(output_, input_, components=None, d=None):
             tensors_to_sum.append(grad_output.extract(c_fields))
             labels[i] = c_fields
     div_result = LabelTensor.summation(tensors_to_sum)
-    div_result.labels = ["+".join(labels)]
     return div_result
 
@@ -249,7 +249,8 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
             result[:, idx] = grad(grad_output, input_, d=di).flatten()
             to_append_tensors[idx] = grad(grad_output, input_, d=di)
             labels[idx] = f"dd{ci[0]}dd{di[0]}"
-        result = LabelTensor.cat(tensors=to_append_tensors, dim=output_.tensor.ndim - 1)
+        result = LabelTensor.cat(tensors=to_append_tensors,
+                                 dim=output_.tensor.ndim - 1)
         result.labels = labels
     return result
diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py
index 600a688..f37557a 100644
--- a/pina/problem/abstract_problem.py
+++ b/pina/problem/abstract_problem.py
@@ -32,11 +32,20 @@ class AbstractProblem(metaclass=ABCMeta):
         # training all type self.collector.full, which returns true if all
         # points are ready.
         self.collector.store_fixed_data()
+        self._batching_dimension = 0
 
     @property
     def collector(self):
         return self._collector
 
+    @property
+    def batching_dimension(self):
+        return self._batching_dimension
+
+    @batching_dimension.setter
+    def batching_dimension(self, value):
+        self._batching_dimension = value
+
     # TODO this should be erase when dataloading will interface collector,
     # kept only for back compatibility
     @property
diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py
index 6f55ded..2d9b4a5 100644
--- a/pina/solvers/solver.py
+++ b/pina/solvers/solver.py
@@ -94,7 +94,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
         pass
 
     @abstractmethod
-    def training_step(self):
+    def training_step(self, batch, batch_idx):
         pass
 
     @abstractmethod
diff --git a/pina/trainer.py b/pina/trainer.py
index 3de0d7e..1601d77 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -79,7 +79,7 @@ class Trainer(pytorch_lightning.Trainer):
         data_module = PinaDataModule(problem=self.solver.problem,
                                      device=device,
                                      train_size=self.train_size,
                                      test_size=self.test_size,
-                                     eval_size=self.eval_size)
+                                     val_size=self.eval_size)
         data_module.setup()
         self._loader = data_module.train_dataloader()
diff --git a/tests/test_label_tensor/test_label_tensor.py b/tests/test_label_tensor/test_label_tensor.py
index 8469767..61e4799 100644
--- a/tests/test_label_tensor/test_label_tensor.py
+++ b/tests/test_label_tensor/test_label_tensor.py
@@ -131,17 +131,17 @@ def test_concatenation_3D():
     data_2 = torch.rand(20, 3, 4)
     labels_2 = ['x', 'y', 'z', 'w']
     lt2 = LabelTensor(data_2, labels_2)
-    with pytest.raises(ValueError):
+    with pytest.raises(RuntimeError):
         LabelTensor.cat([lt1, lt2], dim=2)
 
     data_1 = torch.rand(20, 3, 2)
     labels_1 = ['x', 'y']
     lt1 = LabelTensor(data_1, labels_1)
     data_2 = torch.rand(20, 3, 3)
-    labels_2 = ['x', 'w', 'a']
+    labels_2 = ['z', 'w', 'a']
    lt2 = LabelTensor(data_2, labels_2)
     lt_cat = LabelTensor.cat([lt1, lt2], dim=2)
     assert lt_cat.shape == (20, 3, 5)
-    assert lt_cat.full_labels[2]['dof'] == range(5)
+    assert lt_cat.full_labels[2]['dof'] == ['x', 'y', 'z', 'w', 'a']
     assert lt_cat.full_labels[0]['dof'] == range(20)
     assert lt_cat.full_labels[1]['dof'] == range(3)
 
@@ -157,7 +157,8 @@ def test_summation():
     assert lt_sum.ndim == lt_sum.ndim
     assert lt_sum.shape[0] == 20
     assert lt_sum.shape[1] == 3
-    assert lt_sum.full_labels == labels_all
+    assert lt_sum.full_labels[0] == labels_all[0]
+    assert lt_sum.labels == ['x+x', 'y+y', 'z+z']
     assert torch.eq(lt_sum.tensor, torch.ones(20, 3) * 2).all()
     lt1 = LabelTensor(torch.ones(20, 3), labels_all)
     lt2 = LabelTensor(torch.ones(20, 3), labels_all)
@@ -166,7 +167,8 @@ def test_summation():
     assert lt_sum.ndim == lt_sum.ndim
     assert lt_sum.shape[0] == 20
     assert lt_sum.shape[1] == 3
-    assert lt_sum.full_labels == labels_all
+    assert lt_sum.full_labels[0] == labels_all[0]
+    assert lt_sum.labels == ['x+x+x', 'y+y+y', 'z+z+z']
     assert torch.eq(lt_sum.tensor, torch.ones(20, 3) * 2).all()
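For reference, the `summation` labelling rule the updated tests encode, as a standalone sketch:

```python
import torch

from pina import LabelTensor

lt1 = LabelTensor(torch.ones(20, 3), ['x', 'y', 'z'])
lt2 = LabelTensor(torch.ones(20, 3), ['x', 'y', 'z'])
lt_sum = LabelTensor.summation([lt1, lt2])

# Dof names of the operands are joined with '+'; shapes and all other
# labels must match exactly, otherwise a RuntimeError is raised.
assert lt_sum.labels == ['x+x', 'y+y', 'z+z']
assert torch.eq(lt_sum.tensor, torch.ones(20, 3) * 2).all()
```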