diff --git a/pina/geometry/simplex.py b/pina/geometry/simplex.py index e3670a2..0153540 100644 --- a/pina/geometry/simplex.py +++ b/pina/geometry/simplex.py @@ -45,13 +45,13 @@ class SimplexDomain(Location): check_consistency(simplex_matrix, LabelTensor) # check consistency of labels - self._coordinates = simplex_matrix[0].labels - if not all(vertex.labels == self._coordinates for vertex in simplex_matrix): + matrix_labels = simplex_matrix[0].labels + if not all(vertex.labels == matrix_labels for vertex in simplex_matrix): raise ValueError(f"Labels don't match.") # creating vertices matrix self._vertices_matrix = torch.cat(simplex_matrix) - self._vertices_matrix.labels = self._coordinates + self._vertices_matrix.labels = matrix_labels # creating basis vectors for simplex self._vectors_shifted = ( @@ -59,15 +59,11 @@ class SimplexDomain(Location): ) # build cartesian_bound - self._cartesian_bound = self._build_cartesian(self.variables) - - @property - def coordinates(self): - return self._coordinates + self._cartesian_bound = self._build_cartesian(self._vertices_matrix) @property def variables(self): - return self._vertices_matrix + return self._vertices_matrix.labels def _build_cartesian(self, vertices): """ @@ -80,10 +76,11 @@ class SimplexDomain(Location): span_dict = {} - for i, coord in enumerate(self._coordinates): + for i, coord in enumerate(self.variables): sorted_vertices = sorted(vertices, key=lambda vertex: vertex[i]) # respective coord bounded by the lowest and highest values - span_dict[coord] = [sorted_vertices[0][i], sorted_vertices[-1][i]] + span_dict[coord] = [float(sorted_vertices[0][i]), + float(sorted_vertices[-1][i])] return CartesianDomain(span_dict) @@ -105,7 +102,7 @@ class SimplexDomain(Location): :rtype: bool """ - if not all([label in self.coordinates for label in point.labels]): + if not all(label in self.variables for label in point.labels): raise ValueError( "Point labels different from constructor" f" dictionary labels. Got {point.labels}," @@ -182,7 +179,7 @@ class SimplexDomain(Location): while len(sampled_points) < n: # extract number of vertices - number_of_vertices = self._vertices_matrix.shape[1] + number_of_vertices = self._vertices_matrix.shape[0] # extract idx lambda to set to zero randomly idx_lambda = torch.randint(low=0, high=number_of_vertices, size=(1,)) # build lambda vector @@ -193,8 +190,7 @@ class SimplexDomain(Location): # 3. normalize lambdas /= lambdas.sum() # 4. compute dot product - sampled_points.append(self._vertices_matrix @ lambdas) - + sampled_points.append(self._vertices_matrix.T @ lambdas) return torch.cat(sampled_points, dim=1).T def sample(self, n, mode="random", variables="all"): diff --git a/tests/test_geometry/test_simplex.py b/tests/test_geometry/test_simplex.py index e0cf081..3eacd36 100644 --- a/tests/test_geometry/test_simplex.py +++ b/tests/test_geometry/test_simplex.py @@ -37,6 +37,33 @@ def test_constructor(): ] ) +def test_sample(): + # sampling inside + simplex = SimplexDomain( + [ + LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]), + LabelTensor(torch.tensor([[1, 1]]), labels=["x", "y"]), + LabelTensor(torch.tensor([[0, 2]]), labels=["x", "y"]), + ] + ) + pts = simplex.sample(10) + assert isinstance(pts, LabelTensor) + assert pts.size() == torch.Size([10, 2]) + + # sampling border + SimplexDomain( + [ + LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]), + LabelTensor(torch.tensor([[1, 1]]), labels=["x", "y"]), + LabelTensor(torch.tensor([[0, 2]]), labels=["x", "y"]), + ], + sample_surface=True, + ) + + pts = simplex.sample(10) + assert isinstance(pts, LabelTensor) + assert pts.size() == torch.Size([10, 2]) + def test_is_inside_faulty_point(): domain = SimplexDomain(