add dataset and dataloader for sample points (#195)
* add dataset and dataloader for sample points * unittests
This commit is contained in:
@@ -55,13 +55,15 @@ class SimplexDomain(Location):
|
||||
raise ValueError("An n-dimensional simplex is composed by n + 1 tensors of dimension n.")
|
||||
|
||||
# creating vertices matrix
|
||||
self._vertices_matrix = torch.cat(simplex_matrix)
|
||||
self._vertices_matrix.labels = matrix_labels
|
||||
self._vertices_matrix = LabelTensor.vstack(simplex_matrix)
|
||||
|
||||
# creating basis vectors for simplex
|
||||
self._vectors_shifted = (
|
||||
(self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
|
||||
)
|
||||
# self._vectors_shifted = (
|
||||
# (self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
|
||||
# ) ### TODO: Remove after checking
|
||||
|
||||
vert = self._vertices_matrix
|
||||
self._vectors_shifted = (vert[:-1] - vert[-1]).T
|
||||
|
||||
# build cartesian_bound
|
||||
self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
|
||||
@@ -114,8 +116,8 @@ class SimplexDomain(Location):
|
||||
f" expected {self.variables}."
|
||||
)
|
||||
|
||||
# shift point
|
||||
point_shift = point.T - (self._vertices_matrix.T)[:, None, -1]
|
||||
point_shift = point - self._vertices_matrix[-1]
|
||||
point_shift = point_shift.tensor.reshape(-1, 1)
|
||||
|
||||
# compute barycentric coordinates
|
||||
lambda_ = torch.linalg.solve(self._vectors_shifted * 1.0, point_shift * 1.0)
|
||||
|
||||
Reference in New Issue
Block a user