From 6c8635c316543aaea0e13b37279cd3abd100c452 Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Wed, 28 Jun 2023 14:44:00 +0200 Subject: [PATCH] Variables in Discretise Domain (#139) * fix problems discretise_domain * adding docs, fixing tests --- pina/problem/abstract_problem.py | 108 ++++++++++++++++++++++--------- tests/test_problem.py | 32 +++++---- 2 files changed, 95 insertions(+), 45 deletions(-) diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 8f32e90..e67b15f 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -1,6 +1,6 @@ """ Module for AbstractProblem class """ from abc import ABCMeta, abstractmethod -from ..utils import merge_tensors +from ..utils import merge_tensors, check_consistency class AbstractProblem(metaclass=ABCMeta): @@ -111,53 +111,97 @@ class AbstractProblem(metaclass=ABCMeta): continue self.input_pts[condition_name] = samples - def discretise_domain(self, *args, **kwargs): + def discretise_domain(self, n, mode = 'random', variables = 'all', locations = 'all'): """ Generate a set of points to span the `Location` of all the conditions of the problem. - >>> pinn.span_pts(n=10, mode='grid') - >>> pinn.span_pts(n=10, mode='grid', location=['bound1']) - >>> pinn.span_pts(n=10, mode='grid', variables=['x']) + :param n: Number of points to sample, see Note below + for reference. + :type n: int + :param mode: Mode for sampling, defaults to ``random``. + Available modes include: random sampling, ``random``; + latin hypercube sampling, ``latin`` or ``lh``; + chebyshev sampling, ``chebyshev``; grid sampling ``grid``. + :param variables: problem's variables to be sampled, defaults to 'all'. + :type variables: str or list[str], optional + :param locations: problem's locations from where to sample, defaults to 'all'. + :type locations: str, optional + + :Example: + >>> pinn.span_pts(n=10, mode='grid') + >>> pinn.span_pts(n=10, mode='grid', location=['bound1']) + >>> pinn.span_pts(n=10, mode='grid', variables=['x']) + + .. warning:: + ``random`` is currently the only implemented ``mode`` for all geometries, i.e. + ``EllipsoidDomain``, ``CartesianDomain``, ``SimplexDomain`` and the geometries + compositions ``Union``, ``Difference``, ``Exclusion``, ``Intersection``. The + modes ``latin`` or ``lh``, ``chebyshev``, ``grid`` are only implemented for + ``CartesianDomain``. """ - if all(key in kwargs for key in ['n', 'mode']): - argument = {} - argument['n'] = kwargs['n'] - argument['mode'] = kwargs['mode'] - argument['variables'] = self.input_variables - arguments = [argument] - elif any(key in kwargs for key in ['n', 'mode']) and args: - raise ValueError("Don't mix args and kwargs") - elif isinstance(args[0], int) and isinstance(args[1], str): - argument = {} - argument['n'] = int(args[0]) - argument['mode'] = args[1] - argument['variables'] = self.input_variables - arguments = [argument] - elif all(isinstance(arg, dict) for arg in args): - arguments = args + + # check consistecy n + check_consistency(n, int) + + # check consistency mode + check_consistency(mode, str) + if mode not in ['random', 'grid', 'lh', 'chebyshev', 'latin']: + raise TypeError(f'mode {mode} not valid.') + + # check consistency variables + if variables == 'all': + variables = self.input_variables else: - raise RuntimeError - - locations = kwargs.get('locations', 'all') - + check_consistency(variables, str) + + if sorted(variables) != sorted(self.input_variables): + TypeError(f'Wrong variables for sampling. Variables ', + f'should be in {self.input_variables}.') + + # check consistency location if locations == 'all': locations = [condition for condition in self.conditions] + else: + check_consistency(locations, str) + + if sorted(locations) != sorted(self.conditions): + TypeError(f'Wrong locations for sampling. Location ', + f'should be in {self.conditions}.') + + # sampling for location in locations: condition = self.conditions[location] - samples = tuple(condition.location.sample( - argument['n'], - argument['mode'], - variables=argument['variables']) - for argument in arguments) + # we try to check if we have already sampled + try: + already_sampled = [self.input_pts[location]] + # if we have not sampled, a key error is thrown + except KeyError: + already_sampled = [] + + # if we have already sampled fully the condition + # but we want to sample again we set already_sampled + # to an empty list since we need to sample again, and + # self._have_sampled_points to False. + if self._have_sampled_points[location]: + already_sampled = [] + self._have_sampled_points[location] = False + + # build samples + samples = [condition.location.sample( + n=n, + mode=mode, + variables=variables) + ] + already_sampled pts = merge_tensors(samples) self.input_pts[location] = pts # setting the grad self.input_pts[location].requires_grad_(True) self.input_pts[location].retain_grad() - # the condition is sampled - self._have_sampled_points[location] = True + # the condition is sampled if input_pts contains all labels + if sorted(self.input_pts[location].labels) == sorted(self.input_variables): + self._have_sampled_points[location] = True @property def have_sampled_points(self): diff --git a/tests/test_problem.py b/tests/test_problem.py index d991c1b..41edea6 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -78,20 +78,26 @@ def test_discretise_domain(): poisson_problem.discretise_domain(n, 'lh', locations=['D']) assert poisson_problem.input_pts['D'].shape[0] == n -def test_sampling_all_args(): +def test_sampling_few_variables(): n = 10 - poisson_problem.discretise_domain(n, 'grid', locations=['D']) + poisson_problem.discretise_domain(n, 'grid', locations=['D'], variables=['x']) + assert poisson_problem.input_pts['D'].shape[1] == 1 + assert poisson_problem._have_sampled_points['D'] is False -def test_sampling_all_kwargs(): - n = 10 - poisson_problem.discretise_domain(n=n, mode='latin', locations=['D']) +# def test_sampling_all_args(): +# n = 10 +# poisson_problem.discretise_domain(n, 'grid', locations=['D']) -def test_sampling_dict(): - n = 10 - poisson_problem.discretise_domain( - {'variables': ['x', 'y'], 'mode': 'grid', 'n': n}, locations=['D']) +# def test_sampling_all_kwargs(): +# n = 10 +# poisson_problem.discretise_domain(n=n, mode='latin', locations=['D']) -def test_sampling_mixed_args_kwargs(): - n = 10 - with pytest.raises(ValueError): - poisson_problem.discretise_domain(n, mode='latin', locations=['D']) \ No newline at end of file +# def test_sampling_dict(): +# n = 10 +# poisson_problem.discretise_domain( +# {'variables': ['x', 'y'], 'mode': 'grid', 'n': n}, locations=['D']) + +# def test_sampling_mixed_args_kwargs(): +# n = 10 +# with pytest.raises(ValueError): +# poisson_problem.discretise_domain(n, mode='latin', locations=['D']) \ No newline at end of file