* simplex bug solved --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-235.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
a9b1bd2826
commit
4850b0045d
@@ -45,13 +45,13 @@ class SimplexDomain(Location):
|
||||
check_consistency(simplex_matrix, LabelTensor)
|
||||
|
||||
# check consistency of labels
|
||||
self._coordinates = simplex_matrix[0].labels
|
||||
if not all(vertex.labels == self._coordinates for vertex in simplex_matrix):
|
||||
matrix_labels = simplex_matrix[0].labels
|
||||
if not all(vertex.labels == matrix_labels for vertex in simplex_matrix):
|
||||
raise ValueError(f"Labels don't match.")
|
||||
|
||||
# creating vertices matrix
|
||||
self._vertices_matrix = torch.cat(simplex_matrix)
|
||||
self._vertices_matrix.labels = self._coordinates
|
||||
self._vertices_matrix.labels = matrix_labels
|
||||
|
||||
# creating basis vectors for simplex
|
||||
self._vectors_shifted = (
|
||||
@@ -59,15 +59,11 @@ class SimplexDomain(Location):
|
||||
)
|
||||
|
||||
# build cartesian_bound
|
||||
self._cartesian_bound = self._build_cartesian(self.variables)
|
||||
|
||||
@property
|
||||
def coordinates(self):
|
||||
return self._coordinates
|
||||
self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
return self._vertices_matrix
|
||||
return self._vertices_matrix.labels
|
||||
|
||||
def _build_cartesian(self, vertices):
|
||||
"""
|
||||
@@ -80,10 +76,11 @@ class SimplexDomain(Location):
|
||||
|
||||
span_dict = {}
|
||||
|
||||
for i, coord in enumerate(self._coordinates):
|
||||
for i, coord in enumerate(self.variables):
|
||||
sorted_vertices = sorted(vertices, key=lambda vertex: vertex[i])
|
||||
# respective coord bounded by the lowest and highest values
|
||||
span_dict[coord] = [sorted_vertices[0][i], sorted_vertices[-1][i]]
|
||||
span_dict[coord] = [float(sorted_vertices[0][i]),
|
||||
float(sorted_vertices[-1][i])]
|
||||
|
||||
return CartesianDomain(span_dict)
|
||||
|
||||
@@ -105,7 +102,7 @@ class SimplexDomain(Location):
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
if not all([label in self.coordinates for label in point.labels]):
|
||||
if not all(label in self.variables for label in point.labels):
|
||||
raise ValueError(
|
||||
"Point labels different from constructor"
|
||||
f" dictionary labels. Got {point.labels},"
|
||||
@@ -182,7 +179,7 @@ class SimplexDomain(Location):
|
||||
|
||||
while len(sampled_points) < n:
|
||||
# extract number of vertices
|
||||
number_of_vertices = self._vertices_matrix.shape[1]
|
||||
number_of_vertices = self._vertices_matrix.shape[0]
|
||||
# extract idx lambda to set to zero randomly
|
||||
idx_lambda = torch.randint(low=0, high=number_of_vertices, size=(1,))
|
||||
# build lambda vector
|
||||
@@ -193,8 +190,7 @@ class SimplexDomain(Location):
|
||||
# 3. normalize
|
||||
lambdas /= lambdas.sum()
|
||||
# 4. compute dot product
|
||||
sampled_points.append(self._vertices_matrix @ lambdas)
|
||||
|
||||
sampled_points.append(self._vertices_matrix.T @ lambdas)
|
||||
return torch.cat(sampled_points, dim=1).T
|
||||
|
||||
def sample(self, n, mode="random", variables="all"):
|
||||
|
||||
@@ -37,6 +37,33 @@ def test_constructor():
|
||||
]
|
||||
)
|
||||
|
||||
def test_sample():
|
||||
# sampling inside
|
||||
simplex = SimplexDomain(
|
||||
[
|
||||
LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]),
|
||||
LabelTensor(torch.tensor([[1, 1]]), labels=["x", "y"]),
|
||||
LabelTensor(torch.tensor([[0, 2]]), labels=["x", "y"]),
|
||||
]
|
||||
)
|
||||
pts = simplex.sample(10)
|
||||
assert isinstance(pts, LabelTensor)
|
||||
assert pts.size() == torch.Size([10, 2])
|
||||
|
||||
# sampling border
|
||||
SimplexDomain(
|
||||
[
|
||||
LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]),
|
||||
LabelTensor(torch.tensor([[1, 1]]), labels=["x", "y"]),
|
||||
LabelTensor(torch.tensor([[0, 2]]), labels=["x", "y"]),
|
||||
],
|
||||
sample_surface=True,
|
||||
)
|
||||
|
||||
pts = simplex.sample(10)
|
||||
assert isinstance(pts, LabelTensor)
|
||||
assert pts.size() == torch.Size([10, 2])
|
||||
|
||||
|
||||
def test_is_inside_faulty_point():
|
||||
domain = SimplexDomain(
|
||||
|
||||
Reference in New Issue
Block a user