Solving #179 Simplex Domain Bug (#180)

* simplex bug solved
---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-235.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-10-05 11:26:17 +02:00
committed by Nicola Demo
parent a9b1bd2826
commit 4850b0045d
2 changed files with 38 additions and 15 deletions

View File

@@ -45,13 +45,13 @@ class SimplexDomain(Location):
check_consistency(simplex_matrix, LabelTensor)
# check consistency of labels
self._coordinates = simplex_matrix[0].labels
if not all(vertex.labels == self._coordinates for vertex in simplex_matrix):
matrix_labels = simplex_matrix[0].labels
if not all(vertex.labels == matrix_labels for vertex in simplex_matrix):
raise ValueError(f"Labels don't match.")
# creating vertices matrix
self._vertices_matrix = torch.cat(simplex_matrix)
self._vertices_matrix.labels = self._coordinates
self._vertices_matrix.labels = matrix_labels
# creating basis vectors for simplex
self._vectors_shifted = (
@@ -59,15 +59,11 @@ class SimplexDomain(Location):
)
# build cartesian_bound
self._cartesian_bound = self._build_cartesian(self.variables)
@property
def coordinates(self):
return self._coordinates
self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
@property
def variables(self):
return self._vertices_matrix
return self._vertices_matrix.labels
def _build_cartesian(self, vertices):
"""
@@ -80,10 +76,11 @@ class SimplexDomain(Location):
span_dict = {}
for i, coord in enumerate(self._coordinates):
for i, coord in enumerate(self.variables):
sorted_vertices = sorted(vertices, key=lambda vertex: vertex[i])
# respective coord bounded by the lowest and highest values
span_dict[coord] = [sorted_vertices[0][i], sorted_vertices[-1][i]]
span_dict[coord] = [float(sorted_vertices[0][i]),
float(sorted_vertices[-1][i])]
return CartesianDomain(span_dict)
@@ -105,7 +102,7 @@ class SimplexDomain(Location):
:rtype: bool
"""
if not all([label in self.coordinates for label in point.labels]):
if not all(label in self.variables for label in point.labels):
raise ValueError(
"Point labels different from constructor"
f" dictionary labels. Got {point.labels},"
@@ -182,7 +179,7 @@ class SimplexDomain(Location):
while len(sampled_points) < n:
# extract number of vertices
number_of_vertices = self._vertices_matrix.shape[1]
number_of_vertices = self._vertices_matrix.shape[0]
# extract idx lambda to set to zero randomly
idx_lambda = torch.randint(low=0, high=number_of_vertices, size=(1,))
# build lambda vector
@@ -193,8 +190,7 @@ class SimplexDomain(Location):
# 3. normalize
lambdas /= lambdas.sum()
# 4. compute dot product
sampled_points.append(self._vertices_matrix @ lambdas)
sampled_points.append(self._vertices_matrix.T @ lambdas)
return torch.cat(sampled_points, dim=1).T
def sample(self, n, mode="random", variables="all"):