"""Utils module""" from torch.utils.data import Dataset, DataLoader from functools import reduce import types import torch from torch.utils.data import DataLoader, default_collate, ConcatDataset from .label_tensor import LabelTensor import torch def check_consistency(object, object_instance, subclass=False): """Helper function to check object inheritance consistency. Given a specific ``'object'`` we check if the object is instance of a specific ``'object_instance'``, or in case ``'subclass=True'`` we check if the object is subclass if the ``'object_instance'``. :param (iterable or class object) object: The object to check the inheritance :param Object object_instance: The parent class from where the object is expected to inherit :param str object_name: The name of the object :param bool subclass: Check if is a subclass and not instance :raises ValueError: If the object does not inherit from the specified class """ if not isinstance(object, (list, set, tuple)): object = [object] for obj in object: try: if not subclass: assert isinstance(obj, object_instance) else: assert issubclass(obj, object_instance) except AssertionError: raise ValueError(f"{type(obj).__name__} must be {object_instance}.") def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check """ Return the number of parameters of a given `model`. :param torch.nn.Module model: the torch module to inspect. :param bool aggregate: if True the return values is an integer corresponding to the total amount of parameters of whole model. If False, it returns a dictionary whose keys are the names of layers and the values the corresponding number of parameters. Default is True. :param bool trainable: if True, only trainable parameters are count, otherwise no. Default is True. :return: the number of parameters of the model :rtype: dict or int """ tmp = {} for name, parameter in model.named_parameters(): if only_trainable and not parameter.requires_grad: continue tmp[name] = parameter.numel() if aggregate: tmp = sum(tmp.values()) return tmp def merge_tensors(tensors): # name to be changed if tensors: return reduce(merge_two_tensors, tensors[1:], tensors[0]) raise ValueError("Expected at least one tensor") def merge_two_tensors(tensor1, tensor2): n1 = tensor1.shape[0] n2 = tensor2.shape[0] tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels) return tensor1.append(tensor2) def torch_lhs(n, dim): """Latin Hypercube Sampling torch routine. Sampling in range $[0, 1)^d$. :param int n: number of samples :param int dim: dimensions of latin hypercube :return: samples :rtype: torch.tensor """ if not isinstance(n, int): raise TypeError('number of point n must be int') if not isinstance(dim, int): raise TypeError('dim must be int') if dim < 1: raise ValueError('dim must be greater than one') samples = torch.rand(size=(n, dim)) perms = torch.tile(torch.arange(1, n + 1), (dim, 1)) for row in range(dim): idx_perm = torch.randperm(perms.shape[-1]) perms[row, :] = perms[row, idx_perm] perms = perms.T samples = (perms - samples) / n return samples def is_function(f): """ Checks whether the given object `f` is a function or lambda. :param object f: The object to be checked. :return: `True` if `f` is a function, `False` otherwise. :rtype: bool """ return type(f) == types.FunctionType or type(f) == types.LambdaType def chebyshev_roots(n): """ Return the roots of *n* Chebyshev polynomials (between [-1, 1]). :param int n: number of roots :return: roots :rtype: torch.tensor """ pi = torch.acos(torch.zeros(1)).item() * 2 k = torch.arange(n) nodes = torch.sort(torch.cos(pi * (k + 0.5) / n))[0] return nodes # class PinaDataset(): # def __init__(self, pinn) -> None: # self.pinn = pinn # @property # def dataloader(self): # return self._create_dataloader() # @property # def dataset(self): # return [self.SampleDataset(key, val) # for key, val in self.input_pts.items()] # def _create_dataloader(self): # """Private method for creating dataloader # :return: dataloader # :rtype: torch.utils.data.DataLoader # """ # if self.pinn.batch_size is None: # return {key: [{key: val}] for key, val in self.pinn.input_pts.items()} # def custom_collate(batch): # # extracting pts labels # _, pts = list(batch[0].items())[0] # labels = pts.labels # # calling default torch collate # collate_res = default_collate(batch) # # save collate result in dict # res = {} # for key, val in collate_res.items(): # val.labels = labels # res[key] = val # __init__(self, location, tensor): # self._tensor = tensor # self._location = location # self._len = len(tensor) # def __getitem__(self, index): # tensor = self._tensor.select(0, index) # return {self._location: tensor} # def __len__(self): # return self._len class LabelTensorDataset(Dataset): def __init__(self, d): for k, v in d.items(): setattr(self, k, v) self.labels = list(d.keys()) def __getitem__(self, index): print(index) result = {} for label in self.labels: sample_tensor = getattr(self, label) # print('porcodio') # print(sample_tensor.shape[1]) # print(index) # print(sample_tensor[index]) try: result[label] = sample_tensor[index] except IndexError: result[label] = torch.tensor([]) print(result) return result def __len__(self): return max([len(getattr(self, label)) for label in self.labels]) class LabelTensorDataLoader(DataLoader): def collate_fn(self, data): pass # return dict(zip(self.pinn.input_pts.keys(), dataloaders)) # class SampleDataset(torch.utils.data.Dataset): # def __init__(self, location, tensor): # self._tensor = tensor # self._location = location # self._len = len(tensor) # def __getitem__(self, index): # tensor = self._tensor.select(0, index) # return {self._location: tensor} # def __len__(self): # return self._len class LabelTensorDataset(Dataset): def __init__(self, d): for k, v in d.items(): setattr(self, k, v) self.labels = list(d.keys()) def __getitem__(self, index): print(index) result = {} for label in self.labels: sample_tensor = getattr(self, label) # print('porcodio') # print(sample_tensor.shape[1]) # print(index) # print(sample_tensor[index]) try: result[label] = sample_tensor[index] except IndexError: result[label] = torch.tensor([]) print(result) return result def __len__(self): return max([len(getattr(self, label)) for label in self.labels]) class LabelTensorDataLoader(DataLoader): def collate_fn(self, data): pass