From 7528f6ef7433a03cc6419c7d77fa3269cc8c425e Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Thu, 10 Oct 2024 18:26:52 +0200 Subject: [PATCH] Update of LabelTensor class and fix Simplex domain (#362) *Implement new methods in LabelTensor and fix operators --- pina/collector.py | 29 +- pina/condition/data_condition.py | 4 +- pina/condition/domain_equation_condition.py | 4 +- pina/condition/input_equation_condition.py | 6 +- pina/condition/input_output_condition.py | 4 +- pina/domain/difference_domain.py | 2 +- pina/domain/exclusion_domain.py | 2 +- pina/domain/intersection_domain.py | 2 +- pina/domain/simplex.py | 7 +- pina/domain/union_domain.py | 5 +- pina/label_tensor.py | 252 +++++++++++++----- pina/operators.py | 75 +++--- pina/problem/abstract_problem.py | 13 +- tests/test_condition.py | 22 +- tests/test_geometry/test_simplex.py | 1 - .../test_label_tensor.py | 128 ++++++--- .../test_label_tensor/test_label_tensor_01.py | 117 ++++++++ tests/test_operators.py | 8 +- tests/test_problem.py | 87 +++--- 19 files changed, 551 insertions(+), 217 deletions(-) rename tests/{ => test_label_tensor}/test_label_tensor.py (56%) create mode 100644 tests/test_label_tensor/test_label_tensor_01.py diff --git a/pina/collector.py b/pina/collector.py index fa3247e..0f4e9da 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -1,3 +1,6 @@ +from sympy.strategies.branch import condition + +from . import LabelTensor from .utils import check_consistency, merge_tensors class Collector: @@ -51,7 +54,7 @@ class Collector: already_sampled = [] # if we have sampled the condition but not all variables else: - already_sampled = [self.data_collections[loc].input_points] + already_sampled = [self.data_collections[loc]['input_points']] # if the condition is ready but we want to sample again else: self.is_conditions_ready[loc] = False @@ -63,10 +66,24 @@ class Collector: ] + already_sampled pts = merge_tensors(samples) if ( - sorted(self.data_collections[loc].input_points.labels) - == - sorted(self.problem.input_variables) + set(pts.labels).issubset(sorted(self.problem.input_variables)) ): - self.is_conditions_ready[loc] = True + pts = pts.sort_labels() + if sorted(pts.labels)==sorted(self.problem.input_variables): + self.is_conditions_ready[loc] = True values = [pts, condition.equation] - self.data_collections[loc] = dict(zip(keys, values)) \ No newline at end of file + self.data_collections[loc] = dict(zip(keys, values)) + else: + raise RuntimeError('Try to sample variables which are not in problem defined in the problem') + + def add_points(self, new_points_dict): + """ + Add input points to a sampled condition + + :param new_points_dict: Dictonary of input points (condition_name: LabelTensor) + :raises RuntimeError: if at least one condition is not already sampled + """ + for k,v in new_points_dict.items(): + if not self.is_conditions_ready[k]: + raise RuntimeError('Cannot add points on a non sampled condition') + self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v) \ No newline at end of file diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index b9fe1ed..d5ac639 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -27,4 +27,6 @@ class DataConditionInterface(ConditionInterface): def __setattr__(self, key, value): if (key == 'data') or (key == 'conditionalvariable'): check_consistency(value, (LabelTensor, Graph, torch.Tensor)) - DataConditionInterface.__dict__[key].__set__(self, value) \ No newline at end of file + DataConditionInterface.__dict__[key].__set__(self, value) + elif key in ('_condition_type', '_problem', 'problem', 'condition_type'): + super().__setattr__(key, value) \ No newline at end of file diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index f0ef8e0..ab35d20 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -28,4 +28,6 @@ class DomainEquationCondition(ConditionInterface): DomainEquationCondition.__dict__[key].__set__(self, value) elif key == 'equation': check_consistency(value, (EquationInterface)) - DomainEquationCondition.__dict__[key].__set__(self, value) \ No newline at end of file + DomainEquationCondition.__dict__[key].__set__(self, value) + elif key in ('_condition_type', '_problem', 'problem', 'condition_type'): + super().__setattr__(key, value) \ No newline at end of file diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index f77b025..dc12d02 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -21,7 +21,7 @@ class InputPointsEquationCondition(ConditionInterface): super().__init__() self.input_points = input_points self.equation = equation - self.condition_type = 'physics' + self._condition_type = 'physics' def __setattr__(self, key, value): if key == 'input_points': @@ -29,4 +29,6 @@ class InputPointsEquationCondition(ConditionInterface): InputPointsEquationCondition.__dict__[key].__set__(self, value) elif key == 'equation': check_consistency(value, (EquationInterface)) - InputPointsEquationCondition.__dict__[key].__set__(self, value) \ No newline at end of file + InputPointsEquationCondition.__dict__[key].__set__(self, value) + elif key in ('_condition_type', '_problem', 'problem', 'condition_type'): + super().__setattr__(key, value) \ No newline at end of file diff --git a/pina/condition/input_output_condition.py b/pina/condition/input_output_condition.py index 70388b3..a4fa489 100644 --- a/pina/condition/input_output_condition.py +++ b/pina/condition/input_output_condition.py @@ -26,4 +26,6 @@ class InputOutputPointsCondition(ConditionInterface): def __setattr__(self, key, value): if (key == 'input_points') or (key == 'output_points'): check_consistency(value, (LabelTensor, Graph, torch.Tensor)) - InputOutputPointsCondition.__dict__[key].__set__(self, value) \ No newline at end of file + InputOutputPointsCondition.__dict__[key].__set__(self, value) + elif key in ('_condition_type', '_problem', 'problem', 'condition_type'): + super().__setattr__(key, value) diff --git a/pina/domain/difference_domain.py b/pina/domain/difference_domain.py index 9554aaf..4015a38 100644 --- a/pina/domain/difference_domain.py +++ b/pina/domain/difference_domain.py @@ -77,7 +77,7 @@ class Difference(OperationInterface): 5 """ - if mode != self.sample_modes: + if mode not in self.sample_modes: raise NotImplementedError( f"{mode} is not a valid mode for sampling." ) diff --git a/pina/domain/exclusion_domain.py b/pina/domain/exclusion_domain.py index 4fc582c..a05b154 100644 --- a/pina/domain/exclusion_domain.py +++ b/pina/domain/exclusion_domain.py @@ -76,7 +76,7 @@ class Exclusion(OperationInterface): 5 """ - if mode != self.sample_modes: + if mode not in self.sample_modes: raise NotImplementedError( f"{mode} is not a valid mode for sampling." ) diff --git a/pina/domain/intersection_domain.py b/pina/domain/intersection_domain.py index b580f21..bb0499b 100644 --- a/pina/domain/intersection_domain.py +++ b/pina/domain/intersection_domain.py @@ -78,7 +78,7 @@ class Intersection(OperationInterface): 5 """ - if mode != self.sample_modes: + if mode not in self.sample_modes: raise NotImplementedError( f"{mode} is not a valid mode for sampling." ) diff --git a/pina/domain/simplex.py b/pina/domain/simplex.py index 3d33bff..cea2132 100644 --- a/pina/domain/simplex.py +++ b/pina/domain/simplex.py @@ -92,13 +92,12 @@ class SimplexDomain(DomainInterface): """ span_dict = {} - for i, coord in enumerate(self.variables): - sorted_vertices = sorted(vertices, key=lambda vertex: vertex[i]) + sorted_vertices = torch.sort(vertices[coord].tensor.squeeze()) # respective coord bounded by the lowest and highest values span_dict[coord] = [ - float(sorted_vertices[0][i]), - float(sorted_vertices[-1][i]), + float(sorted_vertices.values[0]), + float(sorted_vertices.values[-1]), ] return CartesianDomain(span_dict) diff --git a/pina/domain/union_domain.py b/pina/domain/union_domain.py index bd7fa56..a72115f 100644 --- a/pina/domain/union_domain.py +++ b/pina/domain/union_domain.py @@ -41,7 +41,10 @@ class Union(OperationInterface): @property def variables(self): - return list(set([geom.variables for geom in self.geometries])) + variables = [] + for geom in self.geometries: + variables+=geom.variables + return list(set(variables)) def is_inside(self, point, check_border=False): """ diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 08e0b03..1df318e 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -1,5 +1,5 @@ """ Module for LabelTensor """ - +from copy import deepcopy, copy import torch from torch import Tensor @@ -35,12 +35,22 @@ class LabelTensor(torch.Tensor): {1: {"name": "space"['a', 'b', 'c']) """ + self.dim_names = None self.labels = labels @property def labels(self): """Property decorator for labels + :return: labels of self + :rtype: list + """ + return self._labels[self.tensor.ndim-1]['dof'] + + @property + def full_labels(self): + """Property decorator for labels + :return: labels of self :rtype: list """ @@ -65,6 +75,13 @@ class LabelTensor(torch.Tensor): self.update_labels_from_list(labels) else: raise ValueError(f"labels must be list, dict or string.") + self.set_names() + + def set_names(self): + labels = self.full_labels + self.dim_names = {} + for dim in range(self.tensor.ndim): + self.dim_names[labels[dim]['name']] = dim def extract(self, label_to_extract): """ @@ -76,46 +93,63 @@ class LabelTensor(torch.Tensor): :raises TypeError: Labels are not ``str``. :raises ValueError: Label to extract is not in the labels ``list``. """ - from copy import deepcopy if isinstance(label_to_extract, (str, int)): label_to_extract = [label_to_extract] if isinstance(label_to_extract, (tuple, list)): - last_dim_label = self._labels[self.tensor.ndim - 1]['dof'] - if set(label_to_extract).issubset(last_dim_label) is False: - raise ValueError('Cannot extract a dof which is not in the original LabelTensor') - idx_to_extract = [last_dim_label.index(i) for i in label_to_extract] - new_tensor = self.tensor - new_tensor = new_tensor[..., idx_to_extract] - new_labels = deepcopy(self._labels) - last_dim_new_label = {self.tensor.ndim - 1: { - 'dof': label_to_extract, - 'name': self._labels[self.tensor.ndim - 1]['name'] - }} - new_labels.update(last_dim_new_label) + return self._extract_from_list(label_to_extract) elif isinstance(label_to_extract, dict): - new_labels = (deepcopy(self._labels)) - new_tensor = self.tensor - for k, v in label_to_extract.items(): - idx_dim = None - for kl, vl in self._labels.items(): - if vl['name'] == k: - idx_dim = kl - break - dim_labels = self._labels[idx_dim]['dof'] - if isinstance(label_to_extract[k], (int, str)): - label_to_extract[k] = [label_to_extract[k]] - if set(label_to_extract[k]).issubset(dim_labels) is False: - raise ValueError('Cannot extract a dof which is not in the original LabelTensor') - idx_to_extract = [dim_labels.index(i) for i in label_to_extract[k]] - indexer = [slice(None)] * idx_dim + [idx_to_extract] + [slice(None)] * (self.tensor.ndim - idx_dim - 1) - new_tensor = new_tensor[indexer] - dim_new_label = {idx_dim: { - 'dof': label_to_extract[k], - 'name': self._labels[idx_dim]['name'] - }} - new_labels.update(dim_new_label) + return self._extract_from_dict(label_to_extract) else: raise ValueError('labels_to_extract must be str or list or dict') + + def _extract_from_list(self, labels_to_extract): + #Store locally all necessary obj/variables + ndim = self.tensor.ndim + labels = self.full_labels + tensor = self.tensor + last_dim_label = self.labels + + #Verify if all the labels in labels_to_extract are in last dimension + if set(labels_to_extract).issubset(last_dim_label) is False: + raise ValueError('Cannot extract a dof which is not in the original LabelTensor') + + #Extract index to extract + idx_to_extract = [last_dim_label.index(i) for i in labels_to_extract] + + #Perform extraction + new_tensor = tensor[..., idx_to_extract] + + #Manage labels + new_labels = copy(labels) + + last_dim_new_label = {ndim - 1: { + 'dof': list(labels_to_extract), + 'name': labels[ndim - 1]['name'] + }} + new_labels.update(last_dim_new_label) + return LabelTensor(new_tensor, new_labels) + + def _extract_from_dict(self, labels_to_extract): + labels = self.full_labels + tensor = self.tensor + ndim = tensor.ndim + new_labels = deepcopy(labels) + new_tensor = tensor + for k, _ in labels_to_extract.items(): + idx_dim = self.dim_names[k] + dim_labels = labels[idx_dim]['dof'] + if isinstance(labels_to_extract[k], (int, str)): + labels_to_extract[k] = [labels_to_extract[k]] + if set(labels_to_extract[k]).issubset(dim_labels) is False: + raise ValueError('Cannot extract a dof which is not in the original LabelTensor') + idx_to_extract = [dim_labels.index(i) for i in labels_to_extract[k]] + indexer = [slice(None)] * idx_dim + [idx_to_extract] + [slice(None)] * (ndim - idx_dim - 1) + new_tensor = new_tensor[indexer] + dim_new_label = {idx_dim: { + 'dof': labels_to_extract[k], + 'name': labels[idx_dim]['name'] + }} + new_labels.update(dim_new_label) return LabelTensor(new_tensor, new_labels) def __str__(self): @@ -147,32 +181,42 @@ class LabelTensor(torch.Tensor): return [] if len(tensors) == 1: return tensors[0] - n_dims = tensors[0].ndim - new_labels_cat_dim = [] - for i in range(n_dims): - name = tensors[0].labels[i]['name'] - if i != dim: - dof = tensors[0].labels[i]['dof'] - for tensor in tensors: - dof_to_check = tensor.labels[i]['dof'] - name_to_check = tensor.labels[i]['name'] - if dof != dof_to_check or name != name_to_check: - raise ValueError('dimensions must have the same dof and name') - else: - for tensor in tensors: - new_labels_cat_dim += tensor.labels[i]['dof'] - name_to_check = tensor.labels[i]['name'] - if name != name_to_check: - raise ValueError('dimensions must have the same dof and name') + new_labels_cat_dim = LabelTensor._check_validity_before_cat(tensors, dim) + + # Perform cat on tensors new_tensor = torch.cat(tensors, dim=dim) - labels = tensors[0].labels + + #Update labels + labels = tensors[0].full_labels labels.pop(dim) new_labels_cat_dim = new_labels_cat_dim if len(set(new_labels_cat_dim)) == len(new_labels_cat_dim) \ else range(new_tensor.shape[dim]) labels[dim] = {'dof': new_labels_cat_dim, - 'name': tensors[1].labels[dim]['name']} + 'name': tensors[1].full_labels[dim]['name']} return LabelTensor(new_tensor, labels) + @staticmethod + def _check_validity_before_cat(tensors, dim): + n_dims = tensors[0].ndim + new_labels_cat_dim = [] + # Check if names and dof of the labels are the same in all dimensions except in dim + for i in range(n_dims): + name = tensors[0].full_labels[i]['name'] + if i != dim: + dof = tensors[0].full_labels[i]['dof'] + for tensor in tensors: + dof_to_check = tensor.full_labels[i]['dof'] + name_to_check = tensor.full_labels[i]['name'] + if dof != dof_to_check or name != name_to_check: + raise ValueError('dimensions must have the same dof and name') + else: + for tensor in tensors: + new_labels_cat_dim += tensor.full_labels[i]['dof'] + name_to_check = tensor.full_labels[i]['name'] + if name != name_to_check: + raise ValueError('Dimensions to concatenate must have the same name') + return new_labels_cat_dim + def requires_grad_(self, mode=True): lt = super().requires_grad_(mode) lt.labels = self._labels @@ -204,7 +248,6 @@ class LabelTensor(torch.Tensor): out = LabelTensor(super().clone(*args, **kwargs), self._labels) return out - def init_labels(self): self._labels = { idx_: { @@ -221,13 +264,14 @@ class LabelTensor(torch.Tensor): :type labels: dict :raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape """ - tensor_shape = self.tensor.shape + #Check dimensionality for k, v in labels.items(): if len(v['dof']) != len(set(v['dof'])): raise ValueError("dof must be unique") if len(v['dof']) != tensor_shape[k]: raise ValueError('Number of dof does not match with tensor dimension') + #Perform update self._labels.update(labels) def update_labels_from_list(self, labels): @@ -237,6 +281,7 @@ class LabelTensor(torch.Tensor): :param labels: The label(s) to update. :type labels: list """ + # Create a dict with labels last_dim_labels = {self.tensor.ndim - 1: {'dof': labels, 'name': self.tensor.ndim - 1}} self.update_labels_from_dict(last_dim_labels) @@ -246,26 +291,103 @@ class LabelTensor(torch.Tensor): raise ValueError('tensors list must not be empty') if len(tensors) == 1: return tensors[0] - labels = tensors[0].labels + # Collect all labels + labels = tensors[0].full_labels + # Check labels of all the tensors in each dimension for j in range(tensors[0].ndim): for i in range(1, len(tensors)): - if labels[j] != tensors[i].labels[j]: + if labels[j] != tensors[i].full_labels[j]: labels.pop(j) break - + # Sum tensors data = torch.zeros(tensors[0].tensor.shape) for i in range(len(tensors)): data += tensors[i].tensor new_tensor = LabelTensor(data, labels) return new_tensor - def last_dim_dof(self): - return self._labels[self.tensor.ndim - 1]['dof'] - def append(self, tensor, mode='std'): - print(self.labels) - print(tensor.labels) if mode == 'std': + # Call cat on last dimension new_label_tensor = LabelTensor.cat([self, tensor], dim=self.tensor.ndim - 1) - + elif mode=='cross': + # Crete tensor and call cat on last dimension + tensor1 = self + tensor2 = tensor + n1 = tensor1.shape[0] + n2 = tensor2.shape[0] + tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) + tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels) + new_label_tensor = LabelTensor.cat([tensor1, tensor2], dim=self.tensor.ndim-1) + else: + raise ValueError('mode must be either "std" or "cross"') return new_label_tensor + + @staticmethod + def vstack(label_tensors): + """ + Stack tensors vertically. For more details, see + :meth:`torch.vstack`. + + :param list(LabelTensor) label_tensors: the tensors to stack. They need + to have equal labels. + :return: the stacked tensor + :rtype: LabelTensor + """ + return LabelTensor.cat(label_tensors, dim=0) + + def __getitem__(self, index): + """ + Return a copy of the selected tensor. + """ + + if isinstance(index, str) or (isinstance(index, (tuple, list)) and all(isinstance(a, str) for a in index)): + return self.extract(index) + + selected_lt = super().__getitem__(index) + + try: + len_index = len(index) + except TypeError: + len_index = 1 + + if isinstance(index, int) or len_index == 1: + if selected_lt.ndim == 1: + selected_lt = selected_lt.reshape(1, -1) + if hasattr(self, "labels"): + new_labels = deepcopy(self.full_labels) + new_labels.pop(0) + selected_lt.labels = new_labels + elif len(index) == self.tensor.ndim: + new_labels = deepcopy(self.full_labels) + if selected_lt.ndim == 1: + selected_lt = selected_lt.reshape(-1, 1) + for j in range(selected_lt.ndim): + if hasattr(self, "labels"): + if isinstance(index[j], list): + new_labels.update({j: {'dof': [new_labels[j]['dof'][i] for i in index[1]], + 'name':new_labels[j]['name']}}) + else: + new_labels.update({j: {'dof': new_labels[j]['dof'][index[j]], + 'name':new_labels[j]['name']}}) + + selected_lt.labels = new_labels + else: + new_labels = deepcopy(self.full_labels) + new_labels.update({0: {'dof': list[index], 'name': new_labels[0]['name']}}) + selected_lt.labels = self.labels + + return selected_lt + + def sort_labels(self, dim=None): + def argsort(lst): + return sorted(range(len(lst)), key=lambda x: lst[x]) + if dim is None: + dim = self.tensor.ndim-1 + labels = self.full_labels[dim]['dof'] + sorted_index = argsort(labels) + indexer = [slice(None)] * self.tensor.ndim + indexer[dim] = sorted_index + new_labels = deepcopy(self.full_labels) + new_labels[dim] = {'dof': sorted(labels), 'name': new_labels[dim]['name']} + return LabelTensor(self.tensor[indexer], new_labels) \ No newline at end of file diff --git a/pina/operators.py b/pina/operators.py index 082d725..8822f20 100644 --- a/pina/operators.py +++ b/pina/operators.py @@ -1,13 +1,11 @@ """ Module for operators vectorize implementation. Differential operators are used to write any differential problem. -These operators are implemented to work on different accelerators: CPU, GPU, TPU or MPS. +These operators are implemented to work on different accellerators: CPU, GPU, TPU or MPS. All operators take as input a tensor onto which computing the operator, a tensor with respect to which computing the operator, the name of the output variables to calculate the operator for (in case of multidimensional functions), and the variables name on which the operator is calculated. """ - import torch -from copy import deepcopy from pina.label_tensor import LabelTensor @@ -49,12 +47,12 @@ def grad(output_, input_, components=None, d=None): :rtype: LabelTensor """ - if len(output_.labels[output_.tensor.ndim-1]['dof']) != 1: + if len(output_.labels) != 1: raise RuntimeError("only scalar function can be differentiated") - if not all([di in input_.labels[input_.tensor.ndim-1]['dof'] for di in d]): + if not all([di in input_.labels for di in d]): raise RuntimeError("derivative labels missing from input tensor") - output_fieldname = output_.labels[output_.ndim-1]['dof'][0] + output_fieldname = output_.labels[0] gradients = torch.autograd.grad( output_, input_, @@ -65,35 +63,37 @@ def grad(output_, input_, components=None, d=None): retain_graph=True, allow_unused=True, )[0] - new_labels = deepcopy(input_.labels) - gradients.labels = new_labels + + gradients.labels = input_.labels gradients = gradients.extract(d) - new_labels[input_.tensor.ndim - 1]['dof'] = [f"d{output_fieldname}d{i}" for i in d] - gradients.labels = new_labels + gradients.labels = [f"d{output_fieldname}d{i}" for i in d] + return gradients if not isinstance(input_, LabelTensor): raise TypeError + if d is None: - d = input_.labels[input_.tensor.ndim-1]['dof'] + d = input_.labels if components is None: - components = output_.labels[output_.tensor.ndim-1]['dof'] + components = output_.labels - if output_.shape[output_.ndim-1] == 1: # scalar output ################################ + if output_.shape[1] == 1: # scalar output ################################ - if components != output_.labels[output_.tensor.ndim-1]['dof']: + if components != output_.labels: raise RuntimeError gradients = grad_scalar_output(output_, input_, d) - elif output_.shape[output_.ndim-1] >= 2: # vector output ############################## + elif output_.shape[output_.ndim - 1] >= 2: # vector output ############################## tensor_to_cat = [] for i, c in enumerate(components): c_output = output_.extract([c]) tensor_to_cat.append(grad_scalar_output(c_output, input_, d)) - gradients = LabelTensor.cat(tensor_to_cat, dim=output_.tensor.ndim-1) + gradients = LabelTensor.cat(tensor_to_cat, dim=output_.tensor.ndim - 1) else: raise NotImplementedError + return gradients @@ -124,30 +124,27 @@ def div(output_, input_, components=None, d=None): raise TypeError if d is None: - d = input_.labels[input_.tensor.ndim-1]['dof'] + d = input_.labels if components is None: - components = output_.labels[output_.tensor.ndim-1]['dof'] + components = output_.labels - if output_.shape[output_.ndim-1] < 2 or len(components) < 2: + if output_.shape[1] < 2 or len(components) < 2: raise ValueError("div supported only for vector fields") if len(components) != len(d): raise ValueError grad_output = grad(output_, input_, components, d) - last_dim_dof = [None] * len(components) - to_sum_tensors = [] + labels = [None] * len(components) + tensors_to_sum = [] for i, (c, d) in enumerate(zip(components, d)): c_fields = f"d{c}d{d}" - last_dim_dof[i] = c_fields - to_sum_tensors.append(grad_output.extract(c_fields)) - - div = LabelTensor.summation(to_sum_tensors) - new_labels = deepcopy(input_.labels) - new_labels[input_.tensor.ndim-1]['dof'] = ["+".join(last_dim_dof)] - div.labels = new_labels - return div + tensors_to_sum.append(grad_output.extract(c_fields)) + labels[i] = c_fields + div_result = LabelTensor.summation(tensors_to_sum) + div_result.labels = ["+".join(labels)] + return div_result def laplacian(output_, input_, components=None, d=None, method="std"): @@ -201,10 +198,10 @@ def laplacian(output_, input_, components=None, d=None, method="std"): return result if d is None: - d = input_.labels[input_.tensor.ndim-1]['dof'] + d = input_.labels if components is None: - components = output_.labels[output_.tensor.ndim-1]['dof'] + components = output_.labels if method == "divgrad": raise NotImplementedError("divgrad not implemented as method") @@ -217,9 +214,9 @@ def laplacian(output_, input_, components=None, d=None, method="std"): # result = scalar_laplace(output_, input_, components, d) # TODO check (from 0.1) grad_output = grad(output_, input_, components=components, d=d) to_append_tensors = [] - for i, label in enumerate(grad_output.labels[grad_output.ndim-1]['dof']): + for i, label in enumerate(grad_output.labels): gg = grad(grad_output, input_, d=d, components=[label]) - to_append_tensors.append(gg.extract([gg.labels[gg.tensor.ndim-1]['dof'][i]])) + to_append_tensors.append(gg.extract([gg.labels[i]])) labels = [f"dd{components[0]}"] result = LabelTensor.summation(tensors=to_append_tensors) result.labels = labels @@ -236,21 +233,27 @@ def laplacian(output_, input_, components=None, d=None, method="std"): # result = result.as_subclass(LabelTensor) # result.labels = labels + result = torch.empty( + input_.shape[0], len(components), device=output_.device + ) labels = [None] * len(components) to_append_tensors = [None] * len(components) for idx, (ci, di) in enumerate(zip(components, d)): + if not isinstance(ci, list): ci = [ci] if not isinstance(di, list): di = [di] + grad_output = grad(output_, input_, components=ci, d=di) + result[:, idx] = grad(grad_output, input_, d=di).flatten() to_append_tensors[idx] = grad(grad_output, input_, d=di) labels[idx] = f"dd{ci[0]}dd{di[0]}" result = LabelTensor.cat(tensors=to_append_tensors, dim=output_.tensor.ndim-1) result.labels = labels return result -# TODO Fix advection operator + def advection(output_, input_, velocity_field, components=None, d=None): """ Perform advection operation. The operator works for vectorial functions, @@ -272,10 +275,10 @@ def advection(output_, input_, velocity_field, components=None, d=None): :rtype: LabelTensor """ if d is None: - d = input_.labels[input_.tensor.ndim-1]['dof'] + d = input_.labels if components is None: - components = output_.labels[output_.tensor.ndim-1]['dof'] + components = output_.labels tmp = ( grad(output_, input_, components, d) diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 8894207..edf214a 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -36,7 +36,15 @@ class AbstractProblem(metaclass=ABCMeta): @property def input_pts(self): - return self.collector.data_collections + to_return = {} + for k, v in self.collector.data_collections.items(): + if 'input_points' in v.keys(): + to_return[k] = v['input_points'] + return to_return + + @property + def _have_sampled_points(self): + return self.collector.is_conditions_ready def __deepcopy__(self, memo): """ @@ -165,3 +173,6 @@ class AbstractProblem(metaclass=ABCMeta): # store data self.collector.store_sample_domains(n, mode, variables, locations) + + def add_points(self, new_points_dict): + self.collector.add_points(new_points_dict) diff --git a/tests/test_condition.py b/tests/test_condition.py index 5f1c623..f12979d 100644 --- a/tests/test_condition.py +++ b/tests/test_condition.py @@ -18,27 +18,27 @@ def test_init_inputoutput(): Condition(input_points=example_input_pts, output_points=example_output_pts) with pytest.raises(ValueError): Condition(example_input_pts, example_output_pts) - with pytest.raises(TypeError): + with pytest.raises(ValueError): Condition(input_points=3., output_points='example') - with pytest.raises(TypeError): + with pytest.raises(ValueError): Condition(input_points=example_domain, output_points=example_domain) +test_init_inputoutput() - -def test_init_locfunc(): - Condition(location=example_domain, equation=FixedValue(0.0)) +def test_init_domainfunc(): + Condition(domain=example_domain, equation=FixedValue(0.0)) with pytest.raises(ValueError): Condition(example_domain, FixedValue(0.0)) - with pytest.raises(TypeError): - Condition(location=3., equation='example') - with pytest.raises(TypeError): - Condition(location=example_input_pts, equation=example_output_pts) + with pytest.raises(ValueError): + Condition(domain=3., equation='example') + with pytest.raises(ValueError): + Condition(domain=example_input_pts, equation=example_output_pts) def test_init_inputfunc(): Condition(input_points=example_input_pts, equation=FixedValue(0.0)) with pytest.raises(ValueError): Condition(example_domain, FixedValue(0.0)) - with pytest.raises(TypeError): + with pytest.raises(ValueError): Condition(input_points=3., equation='example') - with pytest.raises(TypeError): + with pytest.raises(ValueError): Condition(input_points=example_domain, equation=example_output_pts) diff --git a/tests/test_geometry/test_simplex.py b/tests/test_geometry/test_simplex.py index 7fc34ce..25224aa 100644 --- a/tests/test_geometry/test_simplex.py +++ b/tests/test_geometry/test_simplex.py @@ -40,7 +40,6 @@ def test_constructor(): LabelTensor(torch.tensor([[-.5, .5]]), labels=["x", "y"]), ]) - def test_sample(): # sampling inside simplex = SimplexDomain([ diff --git a/tests/test_label_tensor.py b/tests/test_label_tensor/test_label_tensor.py similarity index 56% rename from tests/test_label_tensor.py rename to tests/test_label_tensor/test_label_tensor.py index 6ef484f..1165594 100644 --- a/tests/test_label_tensor.py +++ b/tests/test_label_tensor/test_label_tensor.py @@ -2,7 +2,6 @@ import torch import pytest from pina.label_tensor import LabelTensor -#import pina data = torch.rand((20, 3)) labels_column = { @@ -22,8 +21,7 @@ labels_all = labels_column | labels_row @pytest.mark.parametrize("labels", [labels_column, labels_row, labels_all, labels_list]) def test_constructor(labels): - LabelTensor(data, labels) - + print(LabelTensor(data, labels)) def test_wrong_constructor(): with pytest.raises(ValueError): @@ -92,7 +90,7 @@ def test_extract_3D(): )) assert tensor2.ndim == tensor.ndim assert tensor2.shape == tensor.shape - assert tensor.labels == tensor2.labels + assert tensor.full_labels == tensor2.full_labels assert new.shape != tensor.shape def test_concatenation_3D(): @@ -104,9 +102,9 @@ def test_concatenation_3D(): lt2 = LabelTensor(data_2, labels_2) lt_cat = LabelTensor.cat([lt1, lt2]) assert lt_cat.shape == (70, 3, 4) - assert lt_cat.labels[0]['dof'] == range(70) - assert lt_cat.labels[1]['dof'] == range(3) - assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w'] + assert lt_cat.full_labels[0]['dof'] == range(70) + assert lt_cat.full_labels[1]['dof'] == range(3) + assert lt_cat.full_labels[2]['dof'] == ['x', 'y', 'z', 'w'] data_1 = torch.rand(20, 3, 4) labels_1 = ['x', 'y', 'z', 'w'] @@ -116,9 +114,9 @@ def test_concatenation_3D(): lt2 = LabelTensor(data_2, labels_2) lt_cat = LabelTensor.cat([lt1, lt2], dim=1) assert lt_cat.shape == (20, 5, 4) - assert lt_cat.labels[0]['dof'] == range(20) - assert lt_cat.labels[1]['dof'] == range(5) - assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w'] + assert lt_cat.full_labels[0]['dof'] == range(20) + assert lt_cat.full_labels[1]['dof'] == range(5) + assert lt_cat.full_labels[2]['dof'] == ['x', 'y', 'z', 'w'] data_1 = torch.rand(20, 3, 2) labels_1 = ['x', 'y'] @@ -128,9 +126,9 @@ def test_concatenation_3D(): lt2 = LabelTensor(data_2, labels_2) lt_cat = LabelTensor.cat([lt1, lt2], dim=2) assert lt_cat.shape == (20, 3, 5) - assert lt_cat.labels[2]['dof'] == ['x', 'y', 'z', 'w', 'a'] - assert lt_cat.labels[0]['dof'] == range(20) - assert lt_cat.labels[1]['dof'] == range(3) + assert lt_cat.full_labels[2]['dof'] == ['x', 'y', 'z', 'w', 'a'] + assert lt_cat.full_labels[0]['dof'] == range(20) + assert lt_cat.full_labels[1]['dof'] == range(3) data_1 = torch.rand(20, 2, 4) labels_1 = ['x', 'y', 'z', 'w'] @@ -140,7 +138,6 @@ def test_concatenation_3D(): lt2 = LabelTensor(data_2, labels_2) with pytest.raises(ValueError): LabelTensor.cat([lt1, lt2], dim=2) - data_1 = torch.rand(20, 3, 2) labels_1 = ['x', 'y'] lt1 = LabelTensor(data_1, labels_1) @@ -149,9 +146,9 @@ def test_concatenation_3D(): lt2 = LabelTensor(data_2, labels_2) lt_cat = LabelTensor.cat([lt1, lt2], dim=2) assert lt_cat.shape == (20, 3, 5) - assert lt_cat.labels[2]['dof'] == range(5) - assert lt_cat.labels[0]['dof'] == range(20) - assert lt_cat.labels[1]['dof'] == range(3) + assert lt_cat.full_labels[2]['dof'] == range(5) + assert lt_cat.full_labels[0]['dof'] == range(20) + assert lt_cat.full_labels[1]['dof'] == range(3) def test_summation(): @@ -165,7 +162,7 @@ def test_summation(): assert lt_sum.ndim == lt_sum.ndim assert lt_sum.shape[0] == 20 assert lt_sum.shape[1] == 3 - assert lt_sum.labels == labels_all + assert lt_sum.full_labels == labels_all assert torch.eq(lt_sum.tensor, torch.ones(20,3)*2).all() lt1 = LabelTensor(torch.ones(20,3), labels_all) lt2 = LabelTensor(torch.ones(20,3), labels_all) @@ -174,29 +171,92 @@ def test_summation(): assert lt_sum.ndim == lt_sum.ndim assert lt_sum.shape[0] == 20 assert lt_sum.shape[1] == 3 - assert lt_sum.labels == labels_all + assert lt_sum.full_labels == labels_all assert torch.eq(lt_sum.tensor, torch.ones(20,3)*2).all() def test_append_3D(): - data_1 = torch.rand(20, 3, 4) - labels_1 = ['x', 'y', 'z', 'w'] - lt1 = LabelTensor(data_1, labels_1) - data_2 = torch.rand(50, 3, 4) - labels_2 = ['x', 'y', 'z', 'w'] - lt2 = LabelTensor(data_2, labels_2) - lt1 = lt1.append(lt2) - assert lt1.shape == (70, 3, 4) - assert lt1.labels[0]['dof'] == range(70) - assert lt1.labels[1]['dof'] == range(3) - assert lt1.labels[2]['dof'] == ['x', 'y', 'z', 'w'] data_1 = torch.rand(20, 3, 2) labels_1 = ['x', 'y'] lt1 = LabelTensor(data_1, labels_1) data_2 = torch.rand(20, 3, 2) labels_2 = ['z', 'w'] lt2 = LabelTensor(data_2, labels_2) - lt1 = lt1.append(lt2, mode='cross') + lt1 = lt1.append(lt2) assert lt1.shape == (20, 3, 4) - assert lt1.labels[0]['dof'] == range(20) - assert lt1.labels[1]['dof'] == range(3) - assert lt1.labels[2]['dof'] == ['x', 'y', 'z', 'w'] + assert lt1.full_labels[0]['dof'] == range(20) + assert lt1.full_labels[1]['dof'] == range(3) + assert lt1.full_labels[2]['dof'] == ['x', 'y', 'z', 'w'] + +def test_append_2D(): + data_1 = torch.rand(20, 2) + labels_1 = ['x', 'y'] + lt1 = LabelTensor(data_1, labels_1) + data_2 = torch.rand(20, 2) + labels_2 = ['z', 'w'] + lt2 = LabelTensor(data_2, labels_2) + lt1 = lt1.append(lt2, mode='cross') + assert lt1.shape == (400, 4) + assert lt1.full_labels[0]['dof'] == range(400) + assert lt1.full_labels[1]['dof'] == ['x', 'y', 'z', 'w'] + +def test_vstack_3D(): + data_1 = torch.rand(20, 3, 2) + labels_1 = {1:{'dof': ['a', 'b', 'c'], 'name': 'first'}, 2: {'dof': ['x', 'y'], 'name': 'second'}} + lt1 = LabelTensor(data_1, labels_1) + data_2 = torch.rand(20, 3, 2) + labels_1 = {1:{'dof': ['a', 'b', 'c'], 'name': 'first'}, 2: {'dof': ['x', 'y'], 'name': 'second'}} + lt2 = LabelTensor(data_2, labels_1) + lt_stacked = LabelTensor.vstack([lt1, lt2]) + assert lt_stacked.shape == (40, 3, 2) + assert lt_stacked.full_labels[0]['dof'] == range(40) + assert lt_stacked.full_labels[1]['dof'] == ['a', 'b', 'c'] + assert lt_stacked.full_labels[2]['dof'] == ['x', 'y'] + assert lt_stacked.full_labels[1]['name'] == 'first' + assert lt_stacked.full_labels[2]['name'] == 'second' + +def test_vstack_2D(): + data_1 = torch.rand(20, 2) + labels_1 = { 1: {'dof': ['x', 'y'], 'name': 'second'}} + lt1 = LabelTensor(data_1, labels_1) + data_2 = torch.rand(20, 2) + labels_1 = { 1: {'dof': ['x', 'y'], 'name': 'second'}} + lt2 = LabelTensor(data_2, labels_1) + lt_stacked = LabelTensor.vstack([lt1, lt2]) + assert lt_stacked.shape == (40, 2) + assert lt_stacked.full_labels[0]['dof'] == range(40) + assert lt_stacked.full_labels[1]['dof'] == ['x', 'y'] + assert lt_stacked.full_labels[0]['name'] == 0 + assert lt_stacked.full_labels[1]['name'] == 'second' + +def test_sorting(): + data = torch.ones(20, 5) + data[:,0] = data[:,0]*4 + data[:,1] = data[:,1]*2 + data[:,2] = data[:,2] + data[:,3] = data[:,3]*5 + data[:,4] = data[:,4]*3 + labels = ['d', 'b', 'a', 'e', 'c'] + lt_data = LabelTensor(data, labels) + lt_sorted = LabelTensor.sort_labels(lt_data) + assert lt_sorted.shape == (20,5) + assert lt_sorted.labels == ['a', 'b', 'c', 'd', 'e'] + assert torch.eq(lt_sorted.tensor[:,0], torch.ones(20) * 1).all() + assert torch.eq(lt_sorted.tensor[:,1], torch.ones(20) * 2).all() + assert torch.eq(lt_sorted.tensor[:,2], torch.ones(20) * 3).all() + assert torch.eq(lt_sorted.tensor[:,3], torch.ones(20) * 4).all() + assert torch.eq(lt_sorted.tensor[:,4], torch.ones(20) * 5).all() + + data = torch.ones(20, 4, 5) + data[:,0,:] = data[:,0]*4 + data[:,1,:] = data[:,1]*2 + data[:,2,:] = data[:,2] + data[:,3,:] = data[:,3]*3 + labels = {1: {'dof': ['d', 'b', 'a', 'c'], 'name': 1}} + lt_data = LabelTensor(data, labels) + lt_sorted = LabelTensor.sort_labels(lt_data, dim=1) + assert lt_sorted.shape == (20,4, 5) + assert lt_sorted.full_labels[1]['dof'] == ['a', 'b', 'c', 'd'] + assert torch.eq(lt_sorted.tensor[:,0,:], torch.ones(20,5) * 1).all() + assert torch.eq(lt_sorted.tensor[:,1,:], torch.ones(20,5) * 2).all() + assert torch.eq(lt_sorted.tensor[:,2,:], torch.ones(20,5) * 3).all() + assert torch.eq(lt_sorted.tensor[:,3,:], torch.ones(20,5) * 4).all() diff --git a/tests/test_label_tensor/test_label_tensor_01.py b/tests/test_label_tensor/test_label_tensor_01.py new file mode 100644 index 0000000..a2e129d --- /dev/null +++ b/tests/test_label_tensor/test_label_tensor_01.py @@ -0,0 +1,117 @@ +import torch +import pytest + +from pina import LabelTensor + +data = torch.rand((20, 3)) +labels = ['a', 'b', 'c'] + + +def test_constructor(): + LabelTensor(data, labels) + + +def test_wrong_constructor(): + with pytest.raises(ValueError): + LabelTensor(data, ['a', 'b']) + + +def test_labels(): + tensor = LabelTensor(data, labels) + assert isinstance(tensor, torch.Tensor) + assert tensor.labels == labels + with pytest.raises(ValueError): + tensor.labels = labels[:-1] + + +def test_extract(): + label_to_extract = ['a', 'c'] + tensor = LabelTensor(data, labels) + new = tensor.extract(label_to_extract) + assert new.labels == label_to_extract + assert new.shape[1] == len(label_to_extract) + assert torch.all(torch.isclose(data[:, 0::2], new)) + + +def test_extract_onelabel(): + label_to_extract = ['a'] + tensor = LabelTensor(data, labels) + new = tensor.extract(label_to_extract) + assert new.ndim == 2 + assert new.labels == label_to_extract + assert new.shape[1] == len(label_to_extract) + assert torch.all(torch.isclose(data[:, 0].reshape(-1, 1), new)) + + +def test_wrong_extract(): + label_to_extract = ['a', 'cc'] + tensor = LabelTensor(data, labels) + with pytest.raises(ValueError): + tensor.extract(label_to_extract) + + +def test_extract_order(): + label_to_extract = ['c', 'a'] + tensor = LabelTensor(data, labels) + new = tensor.extract(label_to_extract) + expected = torch.cat( + (data[:, 2].reshape(-1, 1), data[:, 0].reshape(-1, 1)), + dim=1) + assert new.labels == label_to_extract + assert new.shape[1] == len(label_to_extract) + assert torch.all(torch.isclose(expected, new)) + + +def test_merge(): + tensor = LabelTensor(data, labels) + tensor_a = tensor.extract('a') + tensor_b = tensor.extract('b') + tensor_c = tensor.extract('c') + + tensor_bc = tensor_b.append(tensor_c) + assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) + + +def test_merge2(): + tensor = LabelTensor(data, labels) + tensor_b = tensor.extract('b') + tensor_c = tensor.extract('c') + + tensor_bc = tensor_b.append(tensor_c) + assert torch.allclose(tensor_bc, tensor.extract(['b', 'c'])) + + +def test_getitem(): + tensor = LabelTensor(data, labels) + tensor_view = tensor['a'] + assert tensor_view.labels == ['a'] + assert torch.allclose(tensor_view.flatten(), data[:, 0]) + + tensor_view = tensor['a', 'c'] + assert tensor_view.labels == ['a', 'c'] + assert torch.allclose(tensor_view, data[:, 0::2]) + +def test_getitem2(): + tensor = LabelTensor(data, labels) + tensor_view = tensor[:5] + assert tensor_view.labels == labels + assert torch.allclose(tensor_view, data[:5]) + + idx = torch.randperm(tensor.shape[0]) + tensor_view = tensor[idx] + assert tensor_view.labels == labels + +def test_slice(): + tensor = LabelTensor(data, labels) + tensor_view = tensor[:5, :2] + assert tensor_view.labels == labels[:2] + assert torch.allclose(tensor_view, data[:5, :2]) + + tensor_view2 = tensor[3] + + assert tensor_view2.labels == labels + assert torch.allclose(tensor_view2, data[3]) + + tensor_view3 = tensor[:, 2] + assert tensor_view3.labels == labels[2] + assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1)) \ No newline at end of file diff --git a/tests/test_operators.py b/tests/test_operators.py index e18eaf2..1271c37 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -27,15 +27,15 @@ def test_grad_scalar_output(): grad_tensor_s = grad(tensor_s, inp) true_val = 2*inp assert grad_tensor_s.shape == inp.shape - assert grad_tensor_s.labels[grad_tensor_s.ndim-1]['dof'] == [ - f'd{tensor_s.labels[tensor_s.ndim-1]["dof"][0]}d{i}' for i in inp.labels[inp.ndim-1]['dof'] + assert grad_tensor_s.labels == [ + f'd{tensor_s.labels[0]}d{i}' for i in inp.labels ] assert torch.allclose(grad_tensor_s, true_val) grad_tensor_s = grad(tensor_s, inp, d=['x', 'y']) assert grad_tensor_s.shape == (20, 2) - assert grad_tensor_s.labels[grad_tensor_s.ndim-1]['dof'] == [ - f'd{tensor_s.labels[tensor_s.ndim-1]["dof"][0]}d{i}' for i in ['x', 'y'] + assert grad_tensor_s.labels == [ + f'd{tensor_s.labels[0]}d{i}' for i in ['x', 'y'] ] assert torch.allclose(grad_tensor_s, true_val) diff --git a/tests/test_problem.py b/tests/test_problem.py index 77871b1..cc7e255 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -27,50 +27,46 @@ class Poisson(SpatialProblem): conditions = { 'gamma1': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': 1 - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': [0, 1], + 'y': 1 + }), + equation=FixedValue(0.0)), 'gamma2': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': 0 - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': [0, 1], + 'y': 0 + }), + equation=FixedValue(0.0)), 'gamma3': - Condition(domain=CartesianDomain({ - 'x': 1, - 'y': [0, 1] - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': 1, + 'y': [0, 1] + }), + equation=FixedValue(0.0)), 'gamma4': - Condition(domain=CartesianDomain({ - 'x': 0, - 'y': [0, 1] - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': 0, + 'y': [0, 1] + }), + equation=FixedValue(0.0)), 'D': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': [0, 1] - }), - equation=my_laplace), + Condition(domain=CartesianDomain({ + 'x': [0, 1], + 'y': [0, 1] + }), + equation=my_laplace), 'data': - Condition(input_points=in_, output_points=out_) + Condition(input_points=in_, output_points=out_) } def poisson_sol(self, pts): return -(torch.sin(pts.extract(['x']) * torch.pi) * - torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2) + torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi ** 2) truth_solution = poisson_sol -# make the problem -poisson_problem = Poisson() -print(poisson_problem.input_pts) - def test_discretise_domain(): n = 10 poisson_problem = Poisson() @@ -83,7 +79,7 @@ def test_discretise_domain(): assert poisson_problem.input_pts[b].shape[0] == n poisson_problem.discretise_domain(n, 'grid', locations=['D']) - assert poisson_problem.input_pts['D'].shape[0] == n**2 + assert poisson_problem.input_pts['D'].shape[0] == n ** 2 poisson_problem.discretise_domain(n, 'random', locations=['D']) assert poisson_problem.input_pts['D'].shape[0] == n @@ -94,14 +90,15 @@ def test_discretise_domain(): assert poisson_problem.input_pts['D'].shape[0] == n -# def test_sampling_few_variables(): -# n = 10 -# poisson_problem.discretise_domain(n, -# 'grid', -# locations=['D'], -# variables=['x']) -# assert poisson_problem.input_pts['D'].shape[1] == 1 -# assert poisson_problem._have_sampled_points['D'] is False +def test_sampling_few_variables(): + n = 10 + poisson_problem = Poisson() + poisson_problem.discretise_domain(n, + 'grid', + locations=['D'], + variables=['x']) + assert poisson_problem.input_pts['D'].shape[1] == 1 + assert poisson_problem._have_sampled_points['D'] is False def test_variables_correct_order_sampling(): @@ -117,13 +114,11 @@ def test_variables_correct_order_sampling(): variables=['y']) assert poisson_problem.input_pts['D'].labels == sorted( poisson_problem.input_variables) - poisson_problem.discretise_domain(n, 'grid', locations=['D']) assert poisson_problem.input_pts['D'].labels == sorted( poisson_problem.input_variables) - poisson_problem.discretise_domain(n, 'grid', locations=['D'], @@ -140,8 +135,8 @@ def test_add_points(): poisson_problem.discretise_domain(0, 'random', locations=['D'], - variables=['x','y']) - new_pts = LabelTensor(torch.tensor([[0.5,-0.5]]),labels=['x','y']) + variables=['x', 'y']) + new_pts = LabelTensor(torch.tensor([[0.5, -0.5]]), labels=['x', 'y']) poisson_problem.add_points({'D': new_pts}) - assert torch.isclose(poisson_problem.input_pts['D'].extract('x'),new_pts.extract('x')) - assert torch.isclose(poisson_problem.input_pts['D'].extract('y'),new_pts.extract('y')) + assert torch.isclose(poisson_problem.input_pts['D'].extract('x'), new_pts.extract('x')) + assert torch.isclose(poisson_problem.input_pts['D'].extract('y'), new_pts.extract('y'))