From 195224794f74c186667fcc733277d88b39c05cd5 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 7 Feb 2025 15:56:04 +0100 Subject: [PATCH] Implement custom sampling logic --- pina/domain/cartesian.py | 10 +++-- pina/problem/abstract_problem.py | 65 +++++++++++++++++++++++++------- tests/test_collector.py | 35 ++++++++--------- tests/test_problem.py | 50 ++++++++++++++++++------ 4 files changed, 114 insertions(+), 46 deletions(-) diff --git a/pina/domain/cartesian.py b/pina/domain/cartesian.py index e2086e4..48e5e4d 100644 --- a/pina/domain/cartesian.py +++ b/pina/domain/cartesian.py @@ -160,10 +160,10 @@ class CartesianDomain(DomainInterface): pts_variable.labels = [variable] tmp.append(pts_variable) - - result = tmp[0] - for i in tmp[1:]: - result = result.append(i, mode="cross") + if tmp: + result = tmp[0] + for i in tmp[1:]: + result = result.append(i, mode="cross") for variable in variables: if variable in self.fixed_.keys(): @@ -242,6 +242,8 @@ class CartesianDomain(DomainInterface): if self.fixed_ and (not self.range_): return _single_points_sample(n, variables) + if isinstance(variables, str) and variables in self.fixed_.keys(): + return _single_points_sample(n, variables) if mode in ["grid", "chebyshev"]: return _1d_sampler(n, mode, variables).extract(variables) diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index dee16a1..7898c81 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -2,11 +2,12 @@ from abc import ABCMeta, abstractmethod from ..utils import check_consistency -from ..domain import DomainInterface +from ..domain import DomainInterface, CartesianDomain from ..condition.domain_equation_condition import DomainEquationCondition from ..condition import InputPointsEquationCondition from copy import deepcopy -from pina import LabelTensor +from .. 
import LabelTensor +from ..utils import merge_tensors class AbstractProblem(metaclass=ABCMeta): @@ -21,7 +22,7 @@ class AbstractProblem(metaclass=ABCMeta): def __init__(self): - self.discretised_domains = {} + self._discretised_domains = {} # create collector to manage problem data # create hook conditions <-> problems @@ -53,6 +54,10 @@ class AbstractProblem(metaclass=ABCMeta): def batching_dimension(self, value): self._batching_dimension = value + @property + def discretised_domains(self): + return self._discretised_domains + # TODO this should be erase when dataloading will interface collector, # kept only for back compatibility @property @@ -62,7 +67,7 @@ class AbstractProblem(metaclass=ABCMeta): if hasattr(cond, "input_points"): to_return[cond_name] = cond.input_points elif hasattr(cond, "domain"): - to_return[cond_name] = self.discretised_domains[cond.domain] + to_return[cond_name] = self._discretised_domains[cond.domain] return to_return def __deepcopy__(self, memo): @@ -139,9 +144,10 @@ class AbstractProblem(metaclass=ABCMeta): return self.conditions def discretise_domain(self, - n, + n=None, mode="random", - domains="all"): + domains="all", + sample_rules=None): """ Generate a set of points to span the `Location` of all the conditions of the problem. @@ -153,6 +159,8 @@ class AbstractProblem(metaclass=ABCMeta): Available modes include: random sampling, ``random``; latin hypercube sampling, ``latin`` or ``lh``; chebyshev sampling, ``chebyshev``; grid sampling ``grid``. + :param sample_rules: dict mapping each input variable to its own ``n`` and ``mode``; mutually exclusive with ``n``, defaults to None. + :type sample_rules: dict :param domains: problem's domain from where to sample, defaults to 'all'. 
:type domains: str | list[str] @@ -170,25 +178,56 @@ class AbstractProblem(metaclass=ABCMeta): """ # check consistecy n, mode, variables, locations - check_consistency(n, int) - check_consistency(mode, str) + if sample_rules is not None: + check_consistency(sample_rules, dict) + if mode is not None: + check_consistency(mode, str) check_consistency(domains, (list, str)) - # check correct sampling mode - # if mode not in DomainInterface.available_sampling_modes: - # raise TypeError(f"mode {mode} not valid.") - # check correct location if domains == "all": domains = self.domains.keys() elif not isinstance(domains, (list)): domains = [domains] + if n is not None and sample_rules is None: + self._apply_default_discretization(n, mode, domains) + if n is None and sample_rules is not None: + self._apply_custom_discretization(sample_rules, domains) + elif n is not None and sample_rules is not None: + raise RuntimeError( + "You can't specify both n and sample_rules at the same time." + ) + elif n is None and sample_rules is None: + raise RuntimeError( + "You have to specify either n or sample_rules." + ) + def _apply_default_discretization(self, n, mode, domains): for domain in domains: self.discretised_domains[domain] = ( - self.domains[domain].sample(n, mode) + self.domains[domain].sample(n, mode).sort_labels() ) + def _apply_custom_discretization(self, sample_rules, domains): + if sorted(list(sample_rules.keys())) != sorted(self.input_variables): + raise RuntimeError( + "The keys of the sample_rules dictionary must be the same as " + "the input variables." 
+ ) + for domain in domains: + if not isinstance(self.domains[domain], CartesianDomain): + raise RuntimeError( + "Custom discretisation can be applied only on Cartesian " + "domains") + discretised_tensor = [] + for var, rules in sample_rules.items(): + n, mode = rules['n'], rules['mode'] + points = self.domains[domain].sample(n, mode, var) + discretised_tensor.append(points) + + self.discretised_domains[domain] = merge_tensors( + discretised_tensor).sort_labels() + def add_points(self, new_points_dict): """ Add input points to a sampled condition diff --git a/tests/test_collector.py b/tests/test_collector.py index 26967e8..f55997f 100644 --- a/tests/test_collector.py +++ b/tests/test_collector.py @@ -11,22 +11,23 @@ from pina.operators import laplacian from pina.collector import Collector -# def test_supervised_tensor_collector(): -# class SupervisedProblem(AbstractProblem): -# output_variables = None -# conditions = { -# 'data1' : Condition(input_points=torch.rand((10,2)), -# output_points=torch.rand((10,2))), -# 'data2' : Condition(input_points=torch.rand((20,2)), -# output_points=torch.rand((20,2))), -# 'data3' : Condition(input_points=torch.rand((30,2)), -# output_points=torch.rand((30,2))), -# } -# problem = SupervisedProblem() -# collector = Collector(problem) -# for v in collector.conditions_name.values(): -# assert v in problem.conditions.keys() -# assert all(collector._is_conditions_ready.values()) +def test_supervised_tensor_collector(): + class SupervisedProblem(AbstractProblem): + output_variables = None + conditions = { + 'data1': Condition(input_points=torch.rand((10, 2)), + output_points=torch.rand((10, 2))), + 'data2': Condition(input_points=torch.rand((20, 2)), + output_points=torch.rand((20, 2))), + 'data3': Condition(input_points=torch.rand((30, 2)), + output_points=torch.rand((30, 2))), + } + + problem = SupervisedProblem() + collector = Collector(problem) + for v in collector.conditions_name.values(): + assert v in problem.conditions.keys() + 
def test_pinn_collector(): def laplace_equation(input_, output_): @@ -81,7 +82,7 @@ def test_pinn_collector(): def poisson_sol(self, pts): return -(torch.sin(pts.extract(['x']) * torch.pi) * torch.sin(pts.extract(['y']) * torch.pi)) / ( - 2 * torch.pi ** 2) + 2 * torch.pi ** 2) truth_solution = poisson_sol diff --git a/tests/test_problem.py b/tests/test_problem.py index 4930b82..30122d4 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -2,6 +2,8 @@ import torch import pytest from pina.problem.zoo import Poisson2DSquareProblem as Poisson from pina import LabelTensor +from pina.domain import Union +from pina.domain import CartesianDomain def test_discretise_domain(): @@ -29,18 +31,6 @@ def test_discretise_domain(): poisson_problem.discretise_domain(n) -''' -def test_sampling_few_variables(): - n = 10 - poisson_problem = Poisson() - poisson_problem.discretise_domain(n, - 'grid', - domains=['D'], - variables=['x']) - assert poisson_problem.discretised_domains['D'].shape[1] == 1 -''' - - def test_variables_correct_order_sampling(): n = 10 poisson_problem = Poisson() @@ -66,3 +56,39 @@ def test_add_points(): new_pts.extract('x')) assert torch.isclose(poisson_problem.discretised_domains['D'].extract('y'), new_pts.extract('y')) + +@pytest.mark.parametrize( + "mode", + [ + 'random', + 'grid' + ] +) +def test_custom_sampling_logic(mode): + poisson_problem = Poisson() + sampling_rules = { + 'x': {'n': 100, 'mode': mode}, + 'y': {'n': 50, 'mode': mode} + } + poisson_problem.discretise_domain(sample_rules=sampling_rules) + for domain in ['g1', 'g2', 'g3', 'g4']: + assert poisson_problem.discretised_domains[domain].shape[0] == 100 * 50 + assert poisson_problem.discretised_domains[domain].labels == ['x', 'y'] + +@pytest.mark.parametrize( + "mode", + [ + 'random', + 'grid' + ] +) +def test_wrong_custom_sampling_logic(mode): + d2 = CartesianDomain({'x': [1,2], 'y': [0,1] }) + poisson_problem = Poisson() + poisson_problem.domains['D'] = 
Union([poisson_problem.domains['D'], d2]) + sampling_rules = { + 'x': {'n': 100, 'mode': mode}, + 'y': {'n': 50, 'mode': mode} + } + with pytest.raises(RuntimeError): + poisson_problem.discretise_domain(sample_rules=sampling_rules) \ No newline at end of file