is_inside method

This commit is contained in:
Dario Coscia
2023-04-29 12:25:50 +02:00
committed by Nicola Demo
parent 025a7ed0df
commit c8fb7715c4

View File

@@ -31,7 +31,7 @@ class Ellipsoid(Location):
Vol. 7. No. 8. 2001.
:Example:
>>> spatial_domain = Ellipsoid({'x':[1, 0], 'y':[0,1]})
>>> spatial_domain = Ellipsoid({'x':[-1, 1], 'y':[-1,1]})
"""
self.fixed_ = {}
@@ -80,6 +80,43 @@ class Ellipsoid(Location):
"""
return list(self.fixed_.keys()) + list(self.range_.keys())
def is_inside(self, point, check_border=False):
"""Check if a point is inside the ellipsoid.
:param point: Point to be checked
:type point: LabelTensor
:param check_border: Check if the point is also on the frontier
of the ellipsoid, default False.
:type check_border: bool
:return: Returning True if the point is inside, False otherwise.
:rtype: bool
"""
if not isinstance(point, LabelTensor):
raise ValueError('point expected to be LabelTensor.')
# get axis ellipse
list_dict_vals = list(self._axis.values())
tmp = torch.tensor(list_dict_vals, dtype=torch.float)
ax_sq = LabelTensor(tmp.reshape(1, -1), list(self._axis.keys()))
if not all([i in ax_sq.labels for i in point.labels]):
raise ValueError('point labels different from constructor'
f' dictionary labels. Got {point.labels},'
f' expected {ax_sq.labels}.')
# point square
point_sq = point.pow(2)
point_sq.labels = point.labels
# calculate ellispoid equation
eqn = torch.sum(point_sq.extract(ax_sq.labels) / ax_sq) - 1.
if check_border:
return bool(eqn <= 0)
return bool(eqn < 0)
def _sample_range(self, n, mode, variables):
"""Rescale the samples to the correct bounds.