diff --git a/pina/collector.py b/pina/collector.py index 2aa4b3f..5ab1e1b 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -65,12 +65,12 @@ class Collector: def store_sample_domains(self): """ - Add + # TODO: Add docstring """ for condition_name in self.problem.conditions: condition = self.problem.conditions[condition_name] if not hasattr(condition, "domain"): - continue + continue samples = self.problem.discretised_domains[condition.domain] diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 6cadb10..2273c52 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -4,14 +4,14 @@ from abc import ABCMeta, abstractmethod from ..utils import check_consistency from ..domain import DomainInterface from ..condition.domain_equation_condition import DomainEquationCondition -from ..collector import Collector +from ..condition import InputPointsEquationCondition from copy import deepcopy class AbstractProblem(metaclass=ABCMeta): """ The abstract `AbstractProblem` class. All the class defining a PINA Problem - should be inheritied from this class. + should be inherited from this class. In the definition of a PINA problem, the fundamental elements are: the output variables, the condition(s), and the domain(s) where the @@ -27,21 +27,18 @@ class AbstractProblem(metaclass=ABCMeta): for condition_name in self.conditions: self.conditions[condition_name].problem = self - # store in collector all the available fixed points - # note that some points could not be stored at this stage (e.g. when - # sampling locations). To check that all data points are ready for - # training all type self.collector.full, which returns true if all - # points are ready. - # self.collector.store_fixed_data() self._batching_dimension = 0 + # Store in domains dict all the domains object directly passed to + # ConditionInterface. Done for back compatibility with PINA <0.2 if not hasattr(self, "domains"): self.domains = {} - for k, v in self.conditions.items(): - if isinstance(v, DomainEquationCondition): - self.domains[k] = v.domain - self.conditions[k] = DomainEquationCondition( - domain=v.domain, equation=v.equation) + for cond_name, cond in self.conditions.items(): + if isinstance(cond, (DomainEquationCondition, + InputPointsEquationCondition)): + if isinstance(cond.domain, DomainInterface): + self.domains[cond_name] = cond.domain + cond.domain = cond_name # @property # def collector(self): @@ -116,7 +113,6 @@ class AbstractProblem(metaclass=ABCMeta): if hasattr(self, "parameters"): variables += self.parameters - return variables @input_variables.setter @@ -197,9 +193,7 @@ class AbstractProblem(metaclass=ABCMeta): domains = self.domains.keys() elif not isinstance(domains, (list)): domains = [domains] - - print(domains) - print(self.domains) + for domain in domains: self.discretised_domains[domain] = ( self.domains[domain].sample(n, mode, variables) diff --git a/tests/test_collector.py b/tests/test_collector.py index 7df3a1a..b8e8813 100644 --- a/tests/test_collector.py +++ b/tests/test_collector.py @@ -99,7 +99,6 @@ def test_pinn_collector(): if isinstance(v, DomainEquationCondition): assert list(collector.data_collections[k].keys()) == ['input_points', 'equation'] - def test_supervised_graph_collector(): pos = torch.rand((100,3)) x = [torch.rand((100,3)) for _ in range(10)] diff --git a/tests/test_problem.py b/tests/test_problem.py index 6363aee..1698588 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -1,76 +1,11 @@ import torch import pytest - -from pina.problem import SpatialProblem -from pina.operators import laplacian -from pina import LabelTensor, Condition -from pina.domain import CartesianDomain -from pina.equation.equation import Equation -from pina.equation.equation_factory import FixedValue - - -def laplace_equation(input_, output_): - force_term = (torch.sin(input_.extract(['x']) * torch.pi) * - torch.sin(input_.extract(['y']) * torch.pi)) - delta_u = laplacian(output_.extract(['u']), input_) - return delta_u - force_term - - -my_laplace = Equation(laplace_equation) -in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y']) -out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u']) - - -class Poisson(SpatialProblem): - output_variables = ['u'] - spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) - - conditions = { - 'gamma1': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': 1 - }), - equation=FixedValue(0.0)), - 'gamma2': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': 0 - }), - equation=FixedValue(0.0)), - 'gamma3': - Condition(domain=CartesianDomain({ - 'x': 1, - 'y': [0, 1] - }), - equation=FixedValue(0.0)), - 'gamma4': - Condition(domain=CartesianDomain({ - 'x': 0, - 'y': [0, 1] - }), - equation=FixedValue(0.0)), - 'D': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': [0, 1] - }), - equation=my_laplace), - 'data': - Condition(input_points=in_, output_points=out_) - } - - def poisson_sol(self, pts): - return -(torch.sin(pts.extract(['x']) * torch.pi) * - torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2) - - truth_solution = poisson_sol - +from pina.problem.zoo import Poisson2DSquareProblem as Poisson def test_discretise_domain(): n = 10 poisson_problem = Poisson() - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + boundaries = ['g1', 'g2', 'g3', 'g4'] poisson_problem.discretise_domain(n, 'grid', domains=boundaries) for b in boundaries: assert poisson_problem.discretised_domains[b].shape[0] == n @@ -90,8 +25,7 @@ def test_discretise_domain(): assert poisson_problem.discretised_domains['D'].shape[0] == n poisson_problem.discretise_domain(n) - - +''' def test_sampling_few_variables(): n = 10 poisson_problem = Poisson() @@ -100,9 +34,10 @@ def test_sampling_few_variables(): domains=['D'], variables=['x']) assert poisson_problem.discretised_domains['D'].shape[1] == 1 - +''' def test_variables_correct_order_sampling(): + n = 10 poisson_problem = Poisson() poisson_problem.discretise_domain(n, @@ -115,7 +50,6 @@ def test_variables_correct_order_sampling(): assert poisson_problem.discretised_domains['D'].labels == sorted( poisson_problem.input_variables) - # def test_add_points(): # poisson_problem = Poisson() # poisson_problem.discretise_domain(0,