add dataset and dataloader for sample points (#195)

* add dataset and dataloader for sample points
* unittests
This commit is contained in:
Nicola Demo
2023-11-07 11:34:44 +01:00
parent cd5bc9a558
commit d654259428
19 changed files with 581 additions and 196 deletions

View File

@@ -55,13 +55,15 @@ class SimplexDomain(Location):
raise ValueError("An n-dimensional simplex is composed by n + 1 tensors of dimension n.")
# creating vertices matrix
self._vertices_matrix = torch.cat(simplex_matrix)
self._vertices_matrix.labels = matrix_labels
self._vertices_matrix = LabelTensor.vstack(simplex_matrix)
# creating basis vectors for simplex
self._vectors_shifted = (
(self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
)
# self._vectors_shifted = (
# (self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
# ) ### TODO: Remove after checking
vert = self._vertices_matrix
self._vectors_shifted = (vert[:-1] - vert[-1]).T
# build cartesian_bound
self._cartesian_bound = self._build_cartesian(self._vertices_matrix)
@@ -114,8 +116,8 @@ class SimplexDomain(Location):
f" expected {self.variables}."
)
# shift point
point_shift = point.T - (self._vertices_matrix.T)[:, None, -1]
point_shift = point - self._vertices_matrix[-1]
point_shift = point_shift.tensor.reshape(-1, 1)
# compute barycentric coordinates
lambda_ = torch.linalg.solve(self._vectors_shifted * 1.0, point_shift * 1.0)