Solving problems related to Geometry (#118)
* fix and add tests * minor fix on domain classes --------- Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-0-208.WIFIeduroamSTUD.units.it> Co-authored-by: Dario Coscia <dariocoscia@dhcp-040.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
62ec69ccac
commit
982af4a04d
@@ -2,6 +2,7 @@ import torch
|
||||
|
||||
from .location import Location
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
class EllipsoidDomain(Location):
|
||||
@@ -39,9 +40,8 @@ class EllipsoidDomain(Location):
|
||||
self._centers = None
|
||||
self._axis = None
|
||||
|
||||
if not isinstance(sample_surface, bool):
|
||||
raise ValueError('sample_surface must be bool type.')
|
||||
|
||||
# checking consistency
|
||||
check_consistency(sample_surface, bool)
|
||||
self._sample_surface = sample_surface
|
||||
|
||||
for k, v in ellipsoid_dict.items():
|
||||
@@ -81,9 +81,14 @@ class EllipsoidDomain(Location):
|
||||
return list(self.fixed_.keys()) + list(self.range_.keys())
|
||||
|
||||
def is_inside(self, point, check_border=False):
|
||||
"""Check if a point is inside the ellipsoid.
|
||||
"""Check if a point is inside the ellipsoid domain.
|
||||
|
||||
:param point: Point to be checked
|
||||
.. note::
|
||||
When ```'sample_surface'``` in the ```'__init()__'```
|
||||
is set to ```'True'```, then the method only checks
|
||||
points on the surface, and not inside the domain.
|
||||
|
||||
:param point: Point to be checked.
|
||||
:type point: LabelTensor
|
||||
:param check_border: Check if the point is also on the frontier
|
||||
of the ellipsoid, default False.
|
||||
@@ -92,29 +97,40 @@ class EllipsoidDomain(Location):
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
if not isinstance(point, LabelTensor):
|
||||
raise ValueError('point expected to be LabelTensor.')
|
||||
|
||||
# get axis ellipse
|
||||
# small check that point is labeltensor
|
||||
check_consistency(point, LabelTensor)
|
||||
|
||||
# get axis ellipse as tensors
|
||||
list_dict_vals = list(self._axis.values())
|
||||
tmp = torch.tensor(list_dict_vals, dtype=torch.float)
|
||||
ax_sq = LabelTensor(tmp.reshape(1, -1)**2, list(self._axis.keys()))
|
||||
ax_sq = LabelTensor(tmp.reshape(1, -1)**2, self.variables)
|
||||
|
||||
# get centers ellipse as tensors
|
||||
list_dict_vals = list(self._centers.values())
|
||||
tmp = torch.tensor(list_dict_vals, dtype=torch.float)
|
||||
centers = LabelTensor(tmp.reshape(1, -1), self.variables)
|
||||
|
||||
if not all([i in ax_sq.labels for i in point.labels]):
|
||||
raise ValueError('point labels different from constructor'
|
||||
f' dictionary labels. Got {point.labels},'
|
||||
f' expected {ax_sq.labels}.')
|
||||
|
||||
# point square
|
||||
point_sq = point.pow(2)
|
||||
# point square + shift center
|
||||
point_sq = (point - centers).pow(2)
|
||||
point_sq.labels = point.labels
|
||||
|
||||
# calculate ellispoid equation
|
||||
eqn = torch.sum(point_sq.extract(ax_sq.labels) / ax_sq) - 1.
|
||||
|
||||
# if we have sampled only the surface, we check that the
|
||||
# point is inside the surface border only
|
||||
if self._sample_surface:
|
||||
return torch.allclose(eqn, torch.zeros_like(eqn))
|
||||
|
||||
# otherwise we check the ellipse
|
||||
if check_border:
|
||||
return bool(eqn <= 0)
|
||||
|
||||
|
||||
return bool(eqn < 0)
|
||||
|
||||
def _sample_range(self, n, mode, variables):
|
||||
@@ -265,4 +281,4 @@ class EllipsoidDomain(Location):
|
||||
if mode in ['random']:
|
||||
return _Nd_sampler(n, mode, variables)
|
||||
else:
|
||||
raise ValueError(f'mode={mode} is not valid.')
|
||||
raise NotImplemented(f'mode={mode} is not implemented.')
|
||||
|
||||
Reference in New Issue
Block a user