fix tests

This commit is contained in:
Nicola Demo
2025-02-06 16:07:26 +01:00
parent ee7ad797bd
commit 84775849d1
7 changed files with 78 additions and 91 deletions

View File

@@ -62,6 +62,7 @@ class Collector:
# condition now is ready # condition now is ready
self._is_conditions_ready[condition_name] = True self._is_conditions_ready[condition_name] = True
def store_sample_domains(self): def store_sample_domains(self):
""" """
Add Add

View File

@@ -35,6 +35,14 @@ class AbstractProblem(metaclass=ABCMeta):
# self.collector.store_fixed_data() # self.collector.store_fixed_data()
self._batching_dimension = 0 self._batching_dimension = 0
if not hasattr(self, "domains"):
self.domains = {}
for k, v in self.conditions.items():
if isinstance(v, DomainEquationCondition):
self.domains[k] = v.domain
self.conditions[k] = DomainEquationCondition(
domain=v.domain, equation=v.equation)
# @property # @property
# def collector(self): # def collector(self):
# return self._collector # return self._collector
@@ -190,6 +198,8 @@ class AbstractProblem(metaclass=ABCMeta):
elif not isinstance(domains, (list)): elif not isinstance(domains, (list)):
domains = [domains] domains = [domains]
print(domains)
print(self.domains)
for domain in domains: for domain in domains:
self.discretised_domains[domain] = ( self.discretised_domains[domain] = (
self.domains[domain].sample(n, mode, variables) self.domains[domain].sample(n, mode, variables)

View File

@@ -7,10 +7,10 @@ from pina.callbacks import R3Refinement
# make the problem # make the problem
poisson_problem = Poisson() poisson_problem = Poisson()
boundaries = ['nil_g1', 'nil_g2', 'nil_g3', 'nil_g4'] boundaries = ['g1', 'g2', 'g3', 'g4']
n = 10 n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries) poisson_problem.discretise_domain(n, 'grid', domains=boundaries)
poisson_problem.discretise_domain(n, 'grid', locations='laplace_D') poisson_problem.discretise_domain(n, 'grid', domains='D')
model = FeedForward(len(poisson_problem.input_variables), model = FeedForward(len(poisson_problem.input_variables),
len(poisson_problem.output_variables)) len(poisson_problem.output_variables))

View File

@@ -7,10 +7,10 @@ from pina.problem.zoo import Poisson2DSquareProblem as Poisson
# make the problem # make the problem
poisson_problem = Poisson() poisson_problem = Poisson()
boundaries = ['nil_g1', 'nil_g2', 'nil_g3', 'nil_g4'] boundaries = ['g1', 'g2', 'g3', 'g4']
n = 10 n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries) poisson_problem.discretise_domain(n, 'grid', domains=boundaries)
poisson_problem.discretise_domain(n, 'grid', locations='laplace_D') poisson_problem.discretise_domain(n, 'grid', domains='D')
model = FeedForward(len(poisson_problem.input_variables), model = FeedForward(len(poisson_problem.input_variables),
len(poisson_problem.output_variables)) len(poisson_problem.output_variables))

View File

@@ -10,10 +10,10 @@ from pina.optim import TorchOptimizer
# make the problem # make the problem
poisson_problem = Poisson() poisson_problem = Poisson()
boundaries = ['nil_g1', 'nil_g2', 'nil_g3', 'nil_g4'] boundaries = ['g1', 'g2', 'g3', 'g4']
n = 10 n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries) poisson_problem.discretise_domain(n, 'grid', domains=boundaries)
poisson_problem.discretise_domain(n, 'grid', locations='laplace_D') poisson_problem.discretise_domain(n, 'grid', domains='D')
model = FeedForward(len(poisson_problem.input_variables), model = FeedForward(len(poisson_problem.input_variables),
len(poisson_problem.output_variables)) len(poisson_problem.output_variables))

View File

@@ -8,23 +8,24 @@ from pina.domain import CartesianDomain
from pina.equation.equation import Equation from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue from pina.equation.equation_factory import FixedValue
from pina.operators import laplacian from pina.operators import laplacian
from pina.collector import Collector
def test_supervised_tensor_collector(): # def test_supervised_tensor_collector():
class SupervisedProblem(AbstractProblem): # class SupervisedProblem(AbstractProblem):
output_variables = None # output_variables = None
conditions = { # conditions = {
'data1' : Condition(input_points=torch.rand((10,2)), # 'data1' : Condition(input_points=torch.rand((10,2)),
output_points=torch.rand((10,2))), # output_points=torch.rand((10,2))),
'data2' : Condition(input_points=torch.rand((20,2)), # 'data2' : Condition(input_points=torch.rand((20,2)),
output_points=torch.rand((20,2))), # output_points=torch.rand((20,2))),
'data3' : Condition(input_points=torch.rand((30,2)), # 'data3' : Condition(input_points=torch.rand((30,2)),
output_points=torch.rand((30,2))), # output_points=torch.rand((30,2))),
} # }
problem = SupervisedProblem() # problem = SupervisedProblem()
collector = problem.collector # collector = Collector(problem)
for v in collector.conditions_name.values(): # for v in collector.conditions_name.values():
assert v in problem.conditions.keys() # assert v in problem.conditions.keys()
assert all(collector._is_conditions_ready.values()) # assert all(collector._is_conditions_ready.values())
def test_pinn_collector(): def test_pinn_collector():
def laplace_equation(input_, output_): def laplace_equation(input_, output_):
@@ -82,19 +83,18 @@ def test_pinn_collector():
truth_solution = poisson_sol truth_solution = poisson_sol
problem = Poisson() problem = Poisson()
collector = problem.collector boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
problem.discretise_domain(10, 'grid', domains=boundaries)
problem.discretise_domain(10, 'grid', domains='D')
collector = Collector(problem)
collector.store_fixed_data()
collector.store_sample_domains()
for k,v in problem.conditions.items(): for k,v in problem.conditions.items():
if isinstance(v, InputOutputPointsCondition): if isinstance(v, InputOutputPointsCondition):
assert collector._is_conditions_ready[k] == True
assert list(collector.data_collections[k].keys()) == ['input_points', 'output_points'] assert list(collector.data_collections[k].keys()) == ['input_points', 'output_points']
else:
assert collector._is_conditions_ready[k] == False
assert collector.data_collections[k] == {}
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
problem.discretise_domain(10, 'grid', locations=boundaries)
problem.discretise_domain(10, 'grid', locations='D')
assert all(collector._is_conditions_ready.values())
for k,v in problem.conditions.items(): for k,v in problem.conditions.items():
if isinstance(v, DomainEquationCondition): if isinstance(v, DomainEquationCondition):
assert list(collector.data_collections[k].keys()) == ['input_points', 'equation'] assert list(collector.data_collections[k].keys()) == ['input_points', 'equation']
@@ -119,7 +119,8 @@ def test_supervised_graph_collector():
} }
problem = SupervisedProblem() problem = SupervisedProblem()
collector = problem.collector collector = Collector(problem)
assert all(collector._is_conditions_ready.values()) collector.store_fixed_data()
# assert all(collector._is_conditions_ready.values())
for v in collector.conditions_name.values(): for v in collector.conditions_name.values():
assert v in problem.conditions.keys() assert v in problem.conditions.keys()

View File

@@ -71,23 +71,23 @@ def test_discretise_domain():
n = 10 n = 10
poisson_problem = Poisson() poisson_problem = Poisson()
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
poisson_problem.discretise_domain(n, 'grid', locations=boundaries) poisson_problem.discretise_domain(n, 'grid', domains=boundaries)
for b in boundaries: for b in boundaries:
assert poisson_problem.input_pts[b].shape[0] == n assert poisson_problem.discretised_domains[b].shape[0] == n
poisson_problem.discretise_domain(n, 'random', locations=boundaries) poisson_problem.discretise_domain(n, 'random', domains=boundaries)
for b in boundaries: for b in boundaries:
assert poisson_problem.input_pts[b].shape[0] == n assert poisson_problem.discretised_domains[b].shape[0] == n
poisson_problem.discretise_domain(n, 'grid', locations=['D']) poisson_problem.discretise_domain(n, 'grid', domains=['D'])
assert poisson_problem.input_pts['D'].shape[0] == n**2 assert poisson_problem.discretised_domains['D'].shape[0] == n**2
poisson_problem.discretise_domain(n, 'random', locations=['D']) poisson_problem.discretise_domain(n, 'random', domains=['D'])
assert poisson_problem.input_pts['D'].shape[0] == n assert poisson_problem.discretised_domains['D'].shape[0] == n
poisson_problem.discretise_domain(n, 'latin', locations=['D']) poisson_problem.discretise_domain(n, 'latin', domains=['D'])
assert poisson_problem.input_pts['D'].shape[0] == n assert poisson_problem.discretised_domains['D'].shape[0] == n
poisson_problem.discretise_domain(n, 'lh', locations=['D']) poisson_problem.discretise_domain(n, 'lh', domains=['D'])
assert poisson_problem.input_pts['D'].shape[0] == n assert poisson_problem.discretised_domains['D'].shape[0] == n
poisson_problem.discretise_domain(n) poisson_problem.discretise_domain(n)
@@ -97,10 +97,9 @@ def test_sampling_few_variables():
poisson_problem = Poisson() poisson_problem = Poisson()
poisson_problem.discretise_domain(n, poisson_problem.discretise_domain(n,
'grid', 'grid',
locations=['D'], domains=['D'],
variables=['x']) variables=['x'])
assert poisson_problem.input_pts['D'].shape[1] == 1 assert poisson_problem.discretised_domains['D'].shape[1] == 1
assert poisson_problem.collector._is_conditions_ready['D'] is False
def test_variables_correct_order_sampling(): def test_variables_correct_order_sampling():
@@ -108,48 +107,24 @@ def test_variables_correct_order_sampling():
poisson_problem = Poisson() poisson_problem = Poisson()
poisson_problem.discretise_domain(n, poisson_problem.discretise_domain(n,
'grid', 'grid',
locations=['D'], domains=['D'])
variables=['x']) assert poisson_problem.discretised_domains['D'].labels == sorted(
poisson_problem.discretise_domain(n,
'grid',
locations=['D'],
variables=['y'])
assert poisson_problem.input_pts['D'].labels == sorted(
poisson_problem.input_variables) poisson_problem.input_variables)
poisson_problem.discretise_domain(n, 'grid', locations=['D']) poisson_problem.discretise_domain(n, 'grid', domains=['D'])
assert poisson_problem.input_pts['D'].labels == sorted( assert poisson_problem.discretised_domains['D'].labels == sorted(
poisson_problem.input_variables)
poisson_problem.discretise_domain(n,
'grid',
locations=['D'],
variables=['y'])
poisson_problem.discretise_domain(n,
'grid',
locations=['D'],
variables=['x'])
assert poisson_problem.input_pts['D'].labels == sorted(
poisson_problem.input_variables) poisson_problem.input_variables)
def test_add_points(): # def test_add_points():
poisson_problem = Poisson() # poisson_problem = Poisson()
poisson_problem.discretise_domain(0, # poisson_problem.discretise_domain(0,
'random', # 'random',
locations=['D'], # domains=['D'],
variables=['x', 'y']) # variables=['x', 'y'])
new_pts = LabelTensor(torch.tensor([[0.5, -0.5]]), labels=['x', 'y']) # new_pts = LabelTensor(torch.tensor([[0.5, -0.5]]), labels=['x', 'y'])
poisson_problem.add_points({'D': new_pts}) # poisson_problem.add_points({'D': new_pts})
assert torch.isclose(poisson_problem.input_pts['D'].extract('x'), # assert torch.isclose(poisson_problem.discretised_domain['D'].extract('x'),
new_pts.extract('x')) # new_pts.extract('x'))
assert torch.isclose(poisson_problem.input_pts['D'].extract('y'), # assert torch.isclose(poisson_problem.discretised_domain['D'].extract('y'),
new_pts.extract('y')) # new_pts.extract('y'))
def test_collector():
poisson_problem = Poisson()
collector = poisson_problem.collector
assert collector.full is False
assert collector._is_conditions_ready['data'] is True
poisson_problem.discretise_domain(10)
assert collector.full is True