Fix Codacy Warnings (#477)
--------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
e3790e049a
commit
4177bfbb50
@@ -1,6 +1,10 @@
|
||||
"""
|
||||
Module for Simplex Domain.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from .domain_interface import DomainInterface
|
||||
from pina.domain import CartesianDomain
|
||||
from .cartesian import CartesianDomain
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
|
||||
@@ -51,23 +55,20 @@ class SimplexDomain(DomainInterface):
|
||||
# check consistency of labels
|
||||
matrix_labels = simplex_matrix[0].labels
|
||||
if not all(vertex.labels == matrix_labels for vertex in simplex_matrix):
|
||||
raise ValueError(f"Labels don't match.")
|
||||
raise ValueError("Labels don't match.")
|
||||
|
||||
# check consistency dimensions
|
||||
dim_simplex = len(matrix_labels)
|
||||
if len(simplex_matrix) != dim_simplex + 1:
|
||||
raise ValueError(
|
||||
"An n-dimensional simplex is composed by n + 1 tensors of dimension n."
|
||||
"An n-dimensional simplex is composed by n + 1 tensors of "
|
||||
"dimension n."
|
||||
)
|
||||
|
||||
# creating vertices matrix
|
||||
self._vertices_matrix = LabelTensor.vstack(simplex_matrix)
|
||||
|
||||
# creating basis vectors for simplex
|
||||
# self._vectors_shifted = (
|
||||
# (self._vertices_matrix.T)[:, :-1] - (self._vertices_matrix.T)[:, None, -1]
|
||||
# ) ### TODO: Remove after checking
|
||||
|
||||
vert = self._vertices_matrix
|
||||
self._vectors_shifted = (vert[:-1] - vert[-1]).T
|
||||
|
||||
@@ -92,7 +93,7 @@ class SimplexDomain(DomainInterface):
|
||||
"""
|
||||
|
||||
span_dict = {}
|
||||
for i, coord in enumerate(self.variables):
|
||||
for coord in self.variables:
|
||||
sorted_vertices = torch.sort(vertices[coord].tensor.squeeze())
|
||||
# respective coord bounded by the lowest and highest values
|
||||
span_dict[coord] = [
|
||||
@@ -133,6 +134,7 @@ class SimplexDomain(DomainInterface):
|
||||
point_shift = point_shift.tensor.reshape(-1, 1)
|
||||
|
||||
# compute barycentric coordinates
|
||||
|
||||
lambda_ = torch.linalg.solve(
|
||||
self._vectors_shifted * 1.0, point_shift * 1.0
|
||||
)
|
||||
@@ -222,7 +224,8 @@ class SimplexDomain(DomainInterface):
|
||||
Sample n points from Simplex domain.
|
||||
|
||||
:param int n: Number of points to sample in the shape.
|
||||
:param str mode: Mode for sampling, defaults to ``random``. Available modes include: ``random``.
|
||||
:param str mode: Mode for sampling, defaults to ``random``. Available
|
||||
modes include: ``random``.
|
||||
:param variables: Variables to be sampled, defaults to ``all``.
|
||||
:type variables: str | list[str]
|
||||
:return: Returns ``LabelTensor`` of n sampled points.
|
||||
|
||||
Reference in New Issue
Block a user