Use Poisson problem from problems zoo in test_problem and minor fix in AbstractProblem

This commit is contained in:
FilippoOlivo
2025-02-06 16:08:51 +01:00
committed by Nicola Demo
parent 84775849d1
commit c4749efc8b
4 changed files with 18 additions and 91 deletions

View File

@@ -65,12 +65,12 @@ class Collector:
def store_sample_domains(self): def store_sample_domains(self):
""" """
Add # TODO: Add docstring
""" """
for condition_name in self.problem.conditions: for condition_name in self.problem.conditions:
condition = self.problem.conditions[condition_name] condition = self.problem.conditions[condition_name]
if not hasattr(condition, "domain"): if not hasattr(condition, "domain"):
continue continue
samples = self.problem.discretised_domains[condition.domain] samples = self.problem.discretised_domains[condition.domain]

View File

@@ -4,14 +4,14 @@ from abc import ABCMeta, abstractmethod
from ..utils import check_consistency from ..utils import check_consistency
from ..domain import DomainInterface from ..domain import DomainInterface
from ..condition.domain_equation_condition import DomainEquationCondition from ..condition.domain_equation_condition import DomainEquationCondition
from ..collector import Collector from ..condition import InputPointsEquationCondition
from copy import deepcopy from copy import deepcopy
class AbstractProblem(metaclass=ABCMeta): class AbstractProblem(metaclass=ABCMeta):
""" """
The abstract `AbstractProblem` class. All the class defining a PINA Problem The abstract `AbstractProblem` class. All the class defining a PINA Problem
should be inheritied from this class. should be inherited from this class.
In the definition of a PINA problem, the fundamental elements are: In the definition of a PINA problem, the fundamental elements are:
the output variables, the condition(s), and the domain(s) where the the output variables, the condition(s), and the domain(s) where the
@@ -27,21 +27,18 @@ class AbstractProblem(metaclass=ABCMeta):
for condition_name in self.conditions: for condition_name in self.conditions:
self.conditions[condition_name].problem = self self.conditions[condition_name].problem = self
# store in collector all the available fixed points
# note that some points could not be stored at this stage (e.g. when
# sampling locations). To check that all data points are ready for
# training all type self.collector.full, which returns true if all
# points are ready.
# self.collector.store_fixed_data()
self._batching_dimension = 0 self._batching_dimension = 0
# Store in domains dict all the domains object directly passed to
# ConditionInterface. Done for back compatibility with PINA <0.2
if not hasattr(self, "domains"): if not hasattr(self, "domains"):
self.domains = {} self.domains = {}
for k, v in self.conditions.items(): for cond_name, cond in self.conditions.items():
if isinstance(v, DomainEquationCondition): if isinstance(cond, (DomainEquationCondition,
self.domains[k] = v.domain InputPointsEquationCondition)):
self.conditions[k] = DomainEquationCondition( if isinstance(cond.domain, DomainInterface):
domain=v.domain, equation=v.equation) self.domains[cond_name] = cond.domain
cond.domain = cond_name
# @property # @property
# def collector(self): # def collector(self):
@@ -116,7 +113,6 @@ class AbstractProblem(metaclass=ABCMeta):
if hasattr(self, "parameters"): if hasattr(self, "parameters"):
variables += self.parameters variables += self.parameters
return variables return variables
@input_variables.setter @input_variables.setter
@@ -197,9 +193,7 @@ class AbstractProblem(metaclass=ABCMeta):
domains = self.domains.keys() domains = self.domains.keys()
elif not isinstance(domains, (list)): elif not isinstance(domains, (list)):
domains = [domains] domains = [domains]
print(domains)
print(self.domains)
for domain in domains: for domain in domains:
self.discretised_domains[domain] = ( self.discretised_domains[domain] = (
self.domains[domain].sample(n, mode, variables) self.domains[domain].sample(n, mode, variables)

View File

@@ -99,7 +99,6 @@ def test_pinn_collector():
if isinstance(v, DomainEquationCondition): if isinstance(v, DomainEquationCondition):
assert list(collector.data_collections[k].keys()) == ['input_points', 'equation'] assert list(collector.data_collections[k].keys()) == ['input_points', 'equation']
def test_supervised_graph_collector(): def test_supervised_graph_collector():
pos = torch.rand((100,3)) pos = torch.rand((100,3))
x = [torch.rand((100,3)) for _ in range(10)] x = [torch.rand((100,3)) for _ in range(10)]

View File

@@ -1,76 +1,11 @@
import torch import torch
import pytest import pytest
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
from pina.problem import SpatialProblem
from pina.operators import laplacian
from pina import LabelTensor, Condition
from pina.domain import CartesianDomain
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
def laplace_equation(input_, output_):
force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
torch.sin(input_.extract(['y']) * torch.pi))
delta_u = laplacian(output_.extract(['u']), input_)
return delta_u - force_term
my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u'])
class Poisson(SpatialProblem):
output_variables = ['u']
spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
conditions = {
'gamma1':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 1
}),
equation=FixedValue(0.0)),
'gamma2':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 0
}),
equation=FixedValue(0.0)),
'gamma3':
Condition(domain=CartesianDomain({
'x': 1,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'gamma4':
Condition(domain=CartesianDomain({
'x': 0,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'D':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': [0, 1]
}),
equation=my_laplace),
'data':
Condition(input_points=in_, output_points=out_)
}
def poisson_sol(self, pts):
return -(torch.sin(pts.extract(['x']) * torch.pi) *
torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2)
truth_solution = poisson_sol
def test_discretise_domain(): def test_discretise_domain():
n = 10 n = 10
poisson_problem = Poisson() poisson_problem = Poisson()
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] boundaries = ['g1', 'g2', 'g3', 'g4']
poisson_problem.discretise_domain(n, 'grid', domains=boundaries) poisson_problem.discretise_domain(n, 'grid', domains=boundaries)
for b in boundaries: for b in boundaries:
assert poisson_problem.discretised_domains[b].shape[0] == n assert poisson_problem.discretised_domains[b].shape[0] == n
@@ -90,8 +25,7 @@ def test_discretise_domain():
assert poisson_problem.discretised_domains['D'].shape[0] == n assert poisson_problem.discretised_domains['D'].shape[0] == n
poisson_problem.discretise_domain(n) poisson_problem.discretise_domain(n)
'''
def test_sampling_few_variables(): def test_sampling_few_variables():
n = 10 n = 10
poisson_problem = Poisson() poisson_problem = Poisson()
@@ -100,9 +34,10 @@ def test_sampling_few_variables():
domains=['D'], domains=['D'],
variables=['x']) variables=['x'])
assert poisson_problem.discretised_domains['D'].shape[1] == 1 assert poisson_problem.discretised_domains['D'].shape[1] == 1
'''
def test_variables_correct_order_sampling(): def test_variables_correct_order_sampling():
n = 10 n = 10
poisson_problem = Poisson() poisson_problem = Poisson()
poisson_problem.discretise_domain(n, poisson_problem.discretise_domain(n,
@@ -115,7 +50,6 @@ def test_variables_correct_order_sampling():
assert poisson_problem.discretised_domains['D'].labels == sorted( assert poisson_problem.discretised_domains['D'].labels == sorted(
poisson_problem.input_variables) poisson_problem.input_variables)
# def test_add_points(): # def test_add_points():
# poisson_problem = Poisson() # poisson_problem = Poisson()
# poisson_problem.discretise_domain(0, # poisson_problem.discretise_domain(0,