Solving problems related to Geometry (#118)

* fix and add tests
* minor fix on domain classes

---------

Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-0-208.WIFIeduroamSTUD.units.it>
Co-authored-by: Dario Coscia <dariocoscia@dhcp-040.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-06-20 17:30:28 +02:00
committed by Nicola Demo
parent 62ec69ccac
commit 982af4a04d
7 changed files with 122 additions and 35 deletions

View File

@@ -2,6 +2,7 @@ import torch
from .location import Location
from ..label_tensor import LabelTensor
from ..utils import check_consistency
class EllipsoidDomain(Location):
@@ -39,9 +40,8 @@ class EllipsoidDomain(Location):
self._centers = None
self._axis = None
if not isinstance(sample_surface, bool):
raise ValueError('sample_surface must be bool type.')
# checking consistency
check_consistency(sample_surface, bool)
self._sample_surface = sample_surface
for k, v in ellipsoid_dict.items():
@@ -81,9 +81,14 @@ class EllipsoidDomain(Location):
return list(self.fixed_.keys()) + list(self.range_.keys())
def is_inside(self, point, check_border=False):
"""Check if a point is inside the ellipsoid.
"""Check if a point is inside the ellipsoid domain.
:param point: Point to be checked
.. note::
When ```'sample_surface'``` in the ```'__init()__'```
is set to ```'True'```, then the method only checks
points on the surface, and not inside the domain.
:param point: Point to be checked.
:type point: LabelTensor
:param check_border: Check if the point is also on the frontier
of the ellipsoid, default False.
@@ -92,29 +97,40 @@ class EllipsoidDomain(Location):
:rtype: bool
"""
if not isinstance(point, LabelTensor):
raise ValueError('point expected to be LabelTensor.')
# get axis ellipse
# small check that point is labeltensor
check_consistency(point, LabelTensor)
# get axis ellipse as tensors
list_dict_vals = list(self._axis.values())
tmp = torch.tensor(list_dict_vals, dtype=torch.float)
ax_sq = LabelTensor(tmp.reshape(1, -1)**2, list(self._axis.keys()))
ax_sq = LabelTensor(tmp.reshape(1, -1)**2, self.variables)
# get centers ellipse as tensors
list_dict_vals = list(self._centers.values())
tmp = torch.tensor(list_dict_vals, dtype=torch.float)
centers = LabelTensor(tmp.reshape(1, -1), self.variables)
if not all([i in ax_sq.labels for i in point.labels]):
raise ValueError('point labels different from constructor'
f' dictionary labels. Got {point.labels},'
f' expected {ax_sq.labels}.')
# point square
point_sq = point.pow(2)
# point square + shift center
point_sq = (point - centers).pow(2)
point_sq.labels = point.labels
# calculate ellispoid equation
eqn = torch.sum(point_sq.extract(ax_sq.labels) / ax_sq) - 1.
# if we have sampled only the surface, we check that the
# point is inside the surface border only
if self._sample_surface:
return torch.allclose(eqn, torch.zeros_like(eqn))
# otherwise we check the ellipse
if check_border:
return bool(eqn <= 0)
return bool(eqn < 0)
def _sample_range(self, n, mode, variables):
@@ -265,4 +281,4 @@ class EllipsoidDomain(Location):
if mode in ['random']:
return _Nd_sampler(n, mode, variables)
else:
raise ValueError(f'mode={mode} is not valid.')
raise NotImplemented(f'mode={mode} is not implemented.')