fix doc domain

This commit is contained in:
giovanni
2025-03-12 15:27:06 +01:00
committed by Nicola Demo
parent dc5c5c2187
commit ce35af6397
10 changed files with 329 additions and 325 deletions

View File

@@ -1,6 +1,4 @@
"""
Module for Simplex Domain.
"""
"""Module for the Simplex Domain."""
import torch
from .domain_interface import DomainInterface
@@ -10,27 +8,28 @@ from ..utils import check_consistency
class SimplexDomain(DomainInterface):
"""PINA implementation of a Simplex."""
"""
Implementation of the simplex domain.
"""
def __init__(self, simplex_matrix, sample_surface=False):
"""
:param simplex_matrix: A matrix of LabelTensor objects representing
a vertex of the simplex (a tensor), and the coordinates of the
point (a list of labels).
Initialization of the :class:`SimplexDomain` class.
:type simplex_matrix: list[LabelTensor]
:param sample_surface: A variable for choosing sample strategies. If
``sample_surface=True`` only samples on the Simplex surface
frontier are taken. If ``sample_surface=False``, no such criteria
is followed.
:type sample_surface: bool
:param list[LabelTensor] simplex_matrix: A matrix representing the
vertices of the simplex.
:param bool sample_surface: A flag to choose the sampling strategy.
If ``True``, samples are taken only from the surface of the simplex.
If ``False``, samples are taken from the interior of the simplex.
Default is ``False``.
:raises ValueError: If the labels of the vertices don't match.
:raises ValueError: If the number of vertices is not equal to the
dimension of the simplex plus one.
.. warning::
Sampling for dimensions greater or equal to 10 could result
in a shrinking of the simplex, which degrades the quality
of the samples. For dimensions higher than 10, other algorithms
for sampling should be used.
Sampling for dimensions greater or equal to 10 could result in a
shrinkage of the simplex, which degrades the quality of the samples.
For dimensions higher than 10, use other sampling algorithms.
:Example:
>>> spatial_domain = SimplexDomain(
@@ -77,18 +76,30 @@ class SimplexDomain(DomainInterface):
@property
def sample_modes(self):
"""
List of available sampling modes.
:return: List of available sampling modes.
:rtype: list[str]
"""
return ["random"]
@property
def variables(self):
"""
List of variables of the domain.
:return: List of variables of the domain.
:rtype: list[str]
"""
return sorted(self._vertices_matrix.labels)
def _build_cartesian(self, vertices):
"""
Build Cartesian border for Simplex domain to be used in sampling.
:param vertex_matrix: matrix of vertices
:type vertices: list[list]
:return: Cartesian border for triangular domain
Build the cartesian border for a simplex domain to be used in sampling.
:param list[LabelTensor] vertices: Matrix of vertices defining the domain.
:return: The cartesian border for the simplex domain.
:rtype: CartesianDomain
"""
@@ -105,22 +116,16 @@ class SimplexDomain(DomainInterface):
def is_inside(self, point, check_border=False):
"""
Check if a point is inside the simplex.
Uses the algorithm described involving barycentric coordinates:
https://en.wikipedia.org/wiki/Barycentric_coordinate_system.
Check if a point is inside the simplex. It uses an algorithm involving
barycentric coordinates.
:param point: Point to be checked.
:type point: LabelTensor
:param check_border: Check if the point is also on the frontier
of the simplex, default ``False``.
:type check_border: bool
:return: Returning ``True`` if the point is inside, ``False`` otherwise.
:param LabelTensor point: Point to be checked.
:param check_border: If ``True``, the border is considered inside
the simplex. Default is ``False``.
:raises ValueError: If the labels of the point are different from those
passed in the ``__init__`` method.
:return: ``True`` if the point is inside the domain, ``False`` otherwise.
:rtype: bool
.. note::
When ``sample_surface`` in the ``__init()__``
is set to ``True``, then the method only checks
points on the surface, and not inside the domain.
"""
if not all(label in self.variables for label in point.labels):
@@ -134,7 +139,6 @@ class SimplexDomain(DomainInterface):
point_shift = point_shift.tensor.reshape(-1, 1)
# compute barycentric coordinates
lambda_ = torch.linalg.solve(
self._vectors_shifted * 1.0, point_shift * 1.0
)
@@ -151,13 +155,13 @@ class SimplexDomain(DomainInterface):
def _sample_interior_randomly(self, n, variables):
"""
Randomly sample points inside a simplex of arbitrary
dimension, without the boundary.
:param int n: Number of points to sample in the shape.
:param variables: pinn variable to be sampled, defaults to ``all``.
:type variables: str or list[str], optional
:return: Returns tensor of n sampled points.
:rtype: torch.Tensor
Sample at random points from the interior of the simplex. Boundaries are
excluded from this sampling routine.
:param int n: Number of points to sample.
:param list[str] variables: variables to be sampled.
:return: Sampled points.
:rtype: list[torch.Tensor]
"""
# =============== For Developers ================ #
@@ -182,10 +186,10 @@ class SimplexDomain(DomainInterface):
def _sample_boundary_randomly(self, n):
"""
Randomly sample points on the boundary of a simplex
of arbitrary dimensions.
:param int n: Number of points to sample in the shape.
:return: Returns tensor of n sampled points
Sample at random points from the boundary of the simplex.
:param int n: Number of points to sample.
:return: Sampled points.
:rtype: torch.Tensor
"""
@@ -221,20 +225,19 @@ class SimplexDomain(DomainInterface):
def sample(self, n, mode="random", variables="all"):
"""
Sample n points from Simplex domain.
Sampling routine.
:param int n: Number of points to sample in the shape.
:param str mode: Mode for sampling, defaults to ``random``. Available
modes include: ``random``.
:param variables: Variables to be sampled, defaults to ``all``.
:type variables: str | list[str]
:return: Returns ``LabelTensor`` of n sampled points.
:param int n: Number of points to sample.
:param str mode: Sampling method. Default is ``random``.
Available modes: random sampling, ``random``.
:param list[str] variables: variables to be sampled. Default is ``all``.
:raises NotImplementedError: If the sampling method is not implemented.
:return: Sampled points.
:rtype: LabelTensor
.. warning::
When ``sample_surface = True`` in the initialization, all
the variables are sampled, despite passing different once
in ``variables``.
When ``sample_surface=True``, all variables are sampled,
ignoring the ``variables`` parameter.
"""
if variables == "all":