From c8fb7715c43c7a6b6e6ac61a9222295abf63f444 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Sat, 29 Apr 2023 12:25:50 +0200 Subject: [PATCH] is_inside method --- pina/ellipsoid.py | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/pina/ellipsoid.py b/pina/ellipsoid.py index 58b31a4..bbf5fe1 100644 --- a/pina/ellipsoid.py +++ b/pina/ellipsoid.py @@ -31,7 +31,7 @@ class Ellipsoid(Location): Vol. 7. No. 8. 2001. :Example: - >>> spatial_domain = Ellipsoid({'x':[1, 0], 'y':[0,1]}) + >>> spatial_domain = Ellipsoid({'x':[-1, 1], 'y':[-1,1]}) """ self.fixed_ = {} @@ -80,6 +80,43 @@ class Ellipsoid(Location): """ return list(self.fixed_.keys()) + list(self.range_.keys()) + def is_inside(self, point, check_border=False): + """Check if a point is inside the ellipsoid. + + :param point: Point to be checked + :type point: LabelTensor + :param check_border: Check if the point is also on the frontier + of the ellipsoid, default False. + :type check_border: bool + :return: Returning True if the point is inside, False otherwise. + :rtype: bool + """ + + if not isinstance(point, LabelTensor): + raise ValueError('point expected to be LabelTensor.') + + # get axis ellipse + list_dict_vals = list(self._axis.values()) + tmp = torch.tensor(list_dict_vals, dtype=torch.float) + ax_sq = LabelTensor(tmp.reshape(1, -1), list(self._axis.keys())) + + if not all([i in ax_sq.labels for i in point.labels]): + raise ValueError('point labels different from constructor' + f' dictionary labels. Got {point.labels},' + f' expected {ax_sq.labels}.') + + # point square + point_sq = point.pow(2) + point_sq.labels = point.labels + + # calculate ellispoid equation + eqn = torch.sum(point_sq.extract(ax_sq.labels) / ax_sq) - 1. + + if check_border: + return bool(eqn <= 0) + + return bool(eqn < 0) + def _sample_range(self, n, mode, variables): """Rescale the samples to the correct bounds.