Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -1,7 +1,7 @@
import torch
import pytest
from pina import Condition, LabelTensor, Graph
from pina.condition import InputOutputPointsCondition, DomainEquationCondition
from pina.condition import InputTargetCondition, DomainEquationCondition
from pina.graph import RadiusGraph
from pina.problem import AbstractProblem, SpatialProblem
from pina.domain import CartesianDomain
@@ -16,16 +16,16 @@ def test_supervised_tensor_collector():
output_variables = None
conditions = {
"data1": Condition(
input_points=torch.rand((10, 2)),
output_points=torch.rand((10, 2)),
input=torch.rand((10, 2)),
target=torch.rand((10, 2)),
),
"data2": Condition(
input_points=torch.rand((20, 2)),
output_points=torch.rand((20, 2)),
input=torch.rand((20, 2)),
target=torch.rand((20, 2)),
),
"data3": Condition(
input_points=torch.rand((30, 2)),
output_points=torch.rand((30, 2)),
input=torch.rand((30, 2)),
target=torch.rand((30, 2)),
),
}
@@ -74,7 +74,7 @@ def test_pinn_collector():
domain=CartesianDomain({"x": [0, 1], "y": [0, 1]}),
equation=my_laplace,
),
"data": Condition(input_points=in_, output_points=out_),
"data": Condition(input=in_, target=out_),
}
def poisson_sol(self, pts):
@@ -95,16 +95,16 @@ def test_pinn_collector():
collector.store_sample_domains()
for k, v in problem.conditions.items():
if isinstance(v, InputOutputPointsCondition):
if isinstance(v, InputTargetCondition):
assert list(collector.data_collections[k].keys()) == [
"input_points",
"output_points",
"input",
"target",
]
for k, v in problem.conditions.items():
if isinstance(v, DomainEquationCondition):
assert list(collector.data_collections[k].keys()) == [
"input_points",
"input",
"equation",
]
@@ -123,8 +123,8 @@ def test_supervised_graph_collector():
class SupervisedProblem(AbstractProblem):
output_variables = None
conditions = {
"data1": Condition(input_points=graph_list_1, output_points=out_1),
"data2": Condition(input_points=graph_list_2, output_points=out_2),
"data1": Condition(input=graph_list_1, target=out_1),
"data2": Condition(input=graph_list_2, target=out_2),
}
problem = SupervisedProblem()