* simplex bug solved --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-235.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
a9b1bd2826
commit
4850b0045d
@@ -45,13 +45,13 @@ class SimplexDomain(Location):
|
|||||||
check_consistency(simplex_matrix, LabelTensor)
|
check_consistency(simplex_matrix, LabelTensor)
|
||||||
|
|
||||||
# check consistency of labels
|
# check consistency of labels
|
||||||
self._coordinates = simplex_matrix[0].labels
|
matrix_labels = simplex_matrix[0].labels
|
||||||
if not all(vertex.labels == self._coordinates for vertex in simplex_matrix):
|
if not all(vertex.labels == matrix_labels for vertex in simplex_matrix):
|
||||||
raise ValueError(f"Labels don't match.")
|
raise ValueError(f"Labels don't match.")
|
||||||
|
|
||||||
# creating vertices matrix
|
# creating vertices matrix
|
||||||
self._vertices_matrix = torch.cat(simplex_matrix)
|
self._vertices_matrix = torch.cat(simplex_matrix)
|
||||||
self._vertices_matrix.labels = self._coordinates
|
self._vertices_matrix.labels = matrix_labels
|
||||||
|
|
||||||
# creating basis vectors for simplex
|
# creating basis vectors for simplex
|
||||||
self._vectors_shifted = (
|
self._vectors_shifted = (
|
||||||
@@ -59,15 +59,11 @@ class SimplexDomain(Location):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# build cartesian_bound
|
# build cartesian_bound
|
||||||
self._cartesian_bound = self._build_cartesian(self.variables)
|
self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
|
||||||
|
|
||||||
@property
|
|
||||||
def coordinates(self):
|
|
||||||
return self._coordinates
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def variables(self):
|
def variables(self):
|
||||||
return self._vertices_matrix
|
return self._vertices_matrix.labels
|
||||||
|
|
||||||
def _build_cartesian(self, vertices):
|
def _build_cartesian(self, vertices):
|
||||||
"""
|
"""
|
||||||
@@ -80,10 +76,11 @@ class SimplexDomain(Location):
|
|||||||
|
|
||||||
span_dict = {}
|
span_dict = {}
|
||||||
|
|
||||||
for i, coord in enumerate(self._coordinates):
|
for i, coord in enumerate(self.variables):
|
||||||
sorted_vertices = sorted(vertices, key=lambda vertex: vertex[i])
|
sorted_vertices = sorted(vertices, key=lambda vertex: vertex[i])
|
||||||
# respective coord bounded by the lowest and highest values
|
# respective coord bounded by the lowest and highest values
|
||||||
span_dict[coord] = [sorted_vertices[0][i], sorted_vertices[-1][i]]
|
span_dict[coord] = [float(sorted_vertices[0][i]),
|
||||||
|
float(sorted_vertices[-1][i])]
|
||||||
|
|
||||||
return CartesianDomain(span_dict)
|
return CartesianDomain(span_dict)
|
||||||
|
|
||||||
@@ -105,7 +102,7 @@ class SimplexDomain(Location):
|
|||||||
:rtype: bool
|
:rtype: bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not all([label in self.coordinates for label in point.labels]):
|
if not all(label in self.variables for label in point.labels):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Point labels different from constructor"
|
"Point labels different from constructor"
|
||||||
f" dictionary labels. Got {point.labels},"
|
f" dictionary labels. Got {point.labels},"
|
||||||
@@ -182,7 +179,7 @@ class SimplexDomain(Location):
|
|||||||
|
|
||||||
while len(sampled_points) < n:
|
while len(sampled_points) < n:
|
||||||
# extract number of vertices
|
# extract number of vertices
|
||||||
number_of_vertices = self._vertices_matrix.shape[1]
|
number_of_vertices = self._vertices_matrix.shape[0]
|
||||||
# extract idx lambda to set to zero randomly
|
# extract idx lambda to set to zero randomly
|
||||||
idx_lambda = torch.randint(low=0, high=number_of_vertices, size=(1,))
|
idx_lambda = torch.randint(low=0, high=number_of_vertices, size=(1,))
|
||||||
# build lambda vector
|
# build lambda vector
|
||||||
@@ -193,8 +190,7 @@ class SimplexDomain(Location):
|
|||||||
# 3. normalize
|
# 3. normalize
|
||||||
lambdas /= lambdas.sum()
|
lambdas /= lambdas.sum()
|
||||||
# 4. compute dot product
|
# 4. compute dot product
|
||||||
sampled_points.append(self._vertices_matrix @ lambdas)
|
sampled_points.append(self._vertices_matrix.T @ lambdas)
|
||||||
|
|
||||||
return torch.cat(sampled_points, dim=1).T
|
return torch.cat(sampled_points, dim=1).T
|
||||||
|
|
||||||
def sample(self, n, mode="random", variables="all"):
|
def sample(self, n, mode="random", variables="all"):
|
||||||
|
|||||||
@@ -37,6 +37,33 @@ def test_constructor():
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_sample():
|
||||||
|
# sampling inside
|
||||||
|
simplex = SimplexDomain(
|
||||||
|
[
|
||||||
|
LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]),
|
||||||
|
LabelTensor(torch.tensor([[1, 1]]), labels=["x", "y"]),
|
||||||
|
LabelTensor(torch.tensor([[0, 2]]), labels=["x", "y"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
pts = simplex.sample(10)
|
||||||
|
assert isinstance(pts, LabelTensor)
|
||||||
|
assert pts.size() == torch.Size([10, 2])
|
||||||
|
|
||||||
|
# sampling border
|
||||||
|
SimplexDomain(
|
||||||
|
[
|
||||||
|
LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]),
|
||||||
|
LabelTensor(torch.tensor([[1, 1]]), labels=["x", "y"]),
|
||||||
|
LabelTensor(torch.tensor([[0, 2]]), labels=["x", "y"]),
|
||||||
|
],
|
||||||
|
sample_surface=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
pts = simplex.sample(10)
|
||||||
|
assert isinstance(pts, LabelTensor)
|
||||||
|
assert pts.size() == torch.Size([10, 2])
|
||||||
|
|
||||||
|
|
||||||
def test_is_inside_faulty_point():
|
def test_is_inside_faulty_point():
|
||||||
domain = SimplexDomain(
|
domain = SimplexDomain(
|
||||||
|
|||||||
Reference in New Issue
Block a user