diff --git a/pina/collector.py b/pina/collector.py index 5ab1e1b..4c98bd8 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -62,7 +62,6 @@ class Collector: # condition now is ready self._is_conditions_ready[condition_name] = True - def store_sample_domains(self): """ # TODO: Add docstring @@ -70,7 +69,7 @@ class Collector: for condition_name in self.problem.conditions: condition = self.problem.conditions[condition_name] if not hasattr(condition, "domain"): - continue + continue samples = self.problem.discretised_domains[condition.domain] @@ -78,56 +77,3 @@ class Collector: 'input_points': samples, 'equation': condition.equation } - - # # get condition - # condition = self.problem.conditions[loc] - # condition_domain = condition.domain - # if isinstance(condition_domain, str): - # condition_domain = self.problem.domains[condition_domain] - # keys = ["input_points", "equation"] - # # if the condition is not ready, we get and store the data - # if not self._is_conditions_ready[loc]: - # # if it is the first time we sample - # if not self.data_collections[loc]: - # already_sampled = [] - # # if we have sampled the condition but not all variables - # else: - # already_sampled = [ - # self.data_collections[loc]['input_points'] - # ] - # # if the condition is ready but we want to sample again - # else: - # self._is_conditions_ready[loc] = False - # already_sampled = [] - # # get the samples - # samples = [ - # condition_domain.sample(n=n, mode=mode, - # variables=variables) - # ] + already_sampled - # pts = merge_tensors(samples) - # if set(pts.labels).issubset(sorted(self.problem.input_variables)): - # pts = pts.sort_labels() - # if sorted(pts.labels) == sorted(self.problem.input_variables): - # self._is_conditions_ready[loc] = True - # values = [pts, condition.equation] - # self.data_collections[loc] = dict(zip(keys, values)) - # else: - # raise RuntimeError( - # 'Try to sample variables which are not in problem defined ' - # 'in the problem') - - def add_points(self, new_points_dict): - """ - Add input points to a sampled condition - - :param new_points_dict: Dictonary of input points (condition_name: - LabelTensor) - :raises RuntimeError: if at least one condition is not already sampled - """ - for k, v in new_points_dict.items(): - if not self._is_conditions_ready[k]: - raise RuntimeError( - 'Cannot add points on a non sampled condition') - self.data_collections[k]['input_points'] = LabelTensor.vstack( - [self.data_collections[k][ - 'input_points'], v]) diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 3a30e33..dee16a1 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -6,6 +6,7 @@ from ..domain import DomainInterface from ..condition.domain_equation_condition import DomainEquationCondition from ..condition import InputPointsEquationCondition from copy import deepcopy +from pina import LabelTensor class AbstractProblem(metaclass=ABCMeta): @@ -135,12 +136,11 @@ class AbstractProblem(metaclass=ABCMeta): """ The conditions of the problem. """ - return self._conditions + return self.conditions def discretise_domain(self, n, mode="random", - variables="all", domains="all"): """ Generate a set of points to span the `Location` of all the conditions of @@ -153,10 +153,8 @@ class AbstractProblem(metaclass=ABCMeta): Available modes include: random sampling, ``random``; latin hypercube sampling, ``latin`` or ``lh``; chebyshev sampling, ``chebyshev``; grid sampling ``grid``. - :param variables: problem's variables to be sampled, defaults to 'all'. - :type variables: str | list[str] - :param domain: problem's domain from where to sample, defaults to 'all'. - :type locations: str + :param domains: problem's domain from where to sample, defaults to 'all'. + :type domains: str | list[str] :Example: >>> pinn.discretise_domain(n=10, mode='grid') @@ -174,22 +172,12 @@ class AbstractProblem(metaclass=ABCMeta): # check consistecy n, mode, variables, locations check_consistency(n, int) check_consistency(mode, str) - check_consistency(variables, str) check_consistency(domains, (list, str)) # check correct sampling mode # if mode not in DomainInterface.available_sampling_modes: # raise TypeError(f"mode {mode} not valid.") - # check correct variables - if variables == "all": - variables = self.input_variables - for variable in variables: - if variable not in self.input_variables: - TypeError( - f"Wrong variables for sampling. Variables ", - f"should be in {self.input_variables}.", - ) # check correct location if domains == "all": domains = self.domains.keys() @@ -198,14 +186,16 @@ class AbstractProblem(metaclass=ABCMeta): for domain in domains: self.discretised_domains[domain] = ( - self.domains[domain].sample(n, mode, variables) + self.domains[domain].sample(n, mode) ) - # if not isinstance(self.conditions[loc], DomainEquationCondition): - # raise TypeError( - # f"Wrong locations passed, locations for sampling " - # f"should be in {[loc for loc in locations if isinstance(self.conditions[loc], DomainEquationCondition)]}.", - # ) - # store data - # self.collector.store_sample_domains() - # self.collector.store_sample_domains(n, mode, variables, domain) \ No newline at end of file + def add_points(self, new_points_dict): + """ + Add input points to a sampled condition + :param new_points_dict: Dictionary of input points (condition_name: + LabelTensor) + :raises RuntimeError: if at least one condition is not already sampled + """ + for k, v in new_points_dict.items(): + self.discretised_domains[k] = LabelTensor.vstack( + [self.discretised_domains[k], v]) diff --git a/tests/test_collector.py b/tests/test_collector.py index b8e8813..26967e8 100644 --- a/tests/test_collector.py +++ b/tests/test_collector.py @@ -10,6 +10,7 @@ from pina.equation.equation_factory import FixedValue from pina.operators import laplacian from pina.collector import Collector + # def test_supervised_tensor_collector(): # class SupervisedProblem(AbstractProblem): # output_variables = None @@ -37,6 +38,7 @@ def test_pinn_collector(): my_laplace = Equation(laplace_equation) in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y']) out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u']) + class Poisson(SpatialProblem): output_variables = ['u'] spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) @@ -78,7 +80,8 @@ def test_pinn_collector(): def poisson_sol(self, pts): return -(torch.sin(pts.extract(['x']) * torch.pi) * - torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2) + torch.sin(pts.extract(['y']) * torch.pi)) / ( + 2 * torch.pi ** 2) truth_solution = poisson_sol @@ -91,30 +94,34 @@ def test_pinn_collector(): collector.store_fixed_data() collector.store_sample_domains() - for k,v in problem.conditions.items(): + for k, v in problem.conditions.items(): if isinstance(v, InputOutputPointsCondition): - assert list(collector.data_collections[k].keys()) == ['input_points', 'output_points'] + assert list(collector.data_collections[k].keys()) == [ + 'input_points', 'output_points'] - for k,v in problem.conditions.items(): + for k, v in problem.conditions.items(): if isinstance(v, DomainEquationCondition): - assert list(collector.data_collections[k].keys()) == ['input_points', 'equation'] + assert list(collector.data_collections[k].keys()) == [ + 'input_points', 'equation'] + def test_supervised_graph_collector(): - pos = torch.rand((100,3)) - x = [torch.rand((100,3)) for _ in range(10)] + pos = torch.rand((100, 3)) + x = [torch.rand((100, 3)) for _ in range(10)] graph_list_1 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4) - out_1 = torch.rand((10,100,3)) - pos = torch.rand((50,3)) - x = [torch.rand((50,3)) for _ in range(10)] + out_1 = torch.rand((10, 100, 3)) + pos = torch.rand((50, 3)) + x = [torch.rand((50, 3)) for _ in range(10)] graph_list_2 = RadiusGraph(pos=pos, x=x, build_edge_attr=True, r=.4) - out_2 = torch.rand((10,50,3)) + out_2 = torch.rand((10, 50, 3)) + class SupervisedProblem(AbstractProblem): output_variables = None conditions = { - 'data1' : Condition(input_points=graph_list_1, - output_points=out_1), - 'data2' : Condition(input_points=graph_list_2, - output_points=out_2), + 'data1': Condition(input_points=graph_list_1, + output_points=out_1), + 'data2': Condition(input_points=graph_list_2, + output_points=out_2), } problem = SupervisedProblem() diff --git a/tests/test_problem.py b/tests/test_problem.py index 1698588..4930b82 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -1,6 +1,8 @@ import torch import pytest from pina.problem.zoo import Poisson2DSquareProblem as Poisson +from pina import LabelTensor + def test_discretise_domain(): n = 10 @@ -14,7 +16,7 @@ def test_discretise_domain(): assert poisson_problem.discretised_domains[b].shape[0] == n poisson_problem.discretise_domain(n, 'grid', domains=['D']) - assert poisson_problem.discretised_domains['D'].shape[0] == n**2 + assert poisson_problem.discretised_domains['D'].shape[0] == n ** 2 poisson_problem.discretise_domain(n, 'random', domains=['D']) assert poisson_problem.discretised_domains['D'].shape[0] == n @@ -25,6 +27,8 @@ def test_discretise_domain(): assert poisson_problem.discretised_domains['D'].shape[0] == n poisson_problem.discretise_domain(n) + + ''' def test_sampling_few_variables(): n = 10 @@ -36,8 +40,8 @@ def test_sampling_few_variables(): assert poisson_problem.discretised_domains['D'].shape[1] == 1 ''' -def test_variables_correct_order_sampling(): +def test_variables_correct_order_sampling(): n = 10 poisson_problem = Poisson() poisson_problem.discretise_domain(n, @@ -50,15 +54,15 @@ def test_variables_correct_order_sampling(): assert poisson_problem.discretised_domains['D'].labels == sorted( poisson_problem.input_variables) -# def test_add_points(): -# poisson_problem = Poisson() -# poisson_problem.discretise_domain(0, -# 'random', -# domains=['D'], -# variables=['x', 'y']) -# new_pts = LabelTensor(torch.tensor([[0.5, -0.5]]), labels=['x', 'y']) -# poisson_problem.add_points({'D': new_pts}) -# assert torch.isclose(poisson_problem.discretised_domain['D'].extract('x'), -# new_pts.extract('x')) -# assert torch.isclose(poisson_problem.discretised_domain['D'].extract('y'), -# new_pts.extract('y')) \ No newline at end of file + +def test_add_points(): + poisson_problem = Poisson() + poisson_problem.discretise_domain(0, + 'random', + domains=['D']) + new_pts = LabelTensor(torch.tensor([[0.5, -0.5]]), labels=['x', 'y']) + poisson_problem.add_points({'D': new_pts}) + assert torch.isclose(poisson_problem.discretised_domains['D'].extract('x'), + new_pts.extract('x')) + assert torch.isclose(poisson_problem.discretised_domains['D'].extract('y'), + new_pts.extract('y'))