Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import pytest
|
||||
from pina import Condition, LabelTensor, Graph
|
||||
from pina.condition import InputOutputPointsCondition, DomainEquationCondition
|
||||
from pina.condition import InputTargetCondition, DomainEquationCondition
|
||||
from pina.graph import RadiusGraph
|
||||
from pina.problem import AbstractProblem, SpatialProblem
|
||||
from pina.domain import CartesianDomain
|
||||
@@ -16,16 +16,16 @@ def test_supervised_tensor_collector():
|
||||
output_variables = None
|
||||
conditions = {
|
||||
"data1": Condition(
|
||||
input_points=torch.rand((10, 2)),
|
||||
output_points=torch.rand((10, 2)),
|
||||
input=torch.rand((10, 2)),
|
||||
target=torch.rand((10, 2)),
|
||||
),
|
||||
"data2": Condition(
|
||||
input_points=torch.rand((20, 2)),
|
||||
output_points=torch.rand((20, 2)),
|
||||
input=torch.rand((20, 2)),
|
||||
target=torch.rand((20, 2)),
|
||||
),
|
||||
"data3": Condition(
|
||||
input_points=torch.rand((30, 2)),
|
||||
output_points=torch.rand((30, 2)),
|
||||
input=torch.rand((30, 2)),
|
||||
target=torch.rand((30, 2)),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ def test_pinn_collector():
|
||||
domain=CartesianDomain({"x": [0, 1], "y": [0, 1]}),
|
||||
equation=my_laplace,
|
||||
),
|
||||
"data": Condition(input_points=in_, output_points=out_),
|
||||
"data": Condition(input=in_, target=out_),
|
||||
}
|
||||
|
||||
def poisson_sol(self, pts):
|
||||
@@ -95,16 +95,16 @@ def test_pinn_collector():
|
||||
collector.store_sample_domains()
|
||||
|
||||
for k, v in problem.conditions.items():
|
||||
if isinstance(v, InputOutputPointsCondition):
|
||||
if isinstance(v, InputTargetCondition):
|
||||
assert list(collector.data_collections[k].keys()) == [
|
||||
"input_points",
|
||||
"output_points",
|
||||
"input",
|
||||
"target",
|
||||
]
|
||||
|
||||
for k, v in problem.conditions.items():
|
||||
if isinstance(v, DomainEquationCondition):
|
||||
assert list(collector.data_collections[k].keys()) == [
|
||||
"input_points",
|
||||
"input",
|
||||
"equation",
|
||||
]
|
||||
|
||||
@@ -123,8 +123,8 @@ def test_supervised_graph_collector():
|
||||
class SupervisedProblem(AbstractProblem):
|
||||
output_variables = None
|
||||
conditions = {
|
||||
"data1": Condition(input_points=graph_list_1, output_points=out_1),
|
||||
"data2": Condition(input_points=graph_list_2, output_points=out_2),
|
||||
"data1": Condition(input=graph_list_1, target=out_1),
|
||||
"data2": Condition(input=graph_list_2, target=out_2),
|
||||
}
|
||||
|
||||
problem = SupervisedProblem()
|
||||
|
||||
Reference in New Issue
Block a user