Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -1,6 +1,6 @@
import torch
from pina.problem import AbstractProblem
from pina.condition import InputOutputPointsCondition
from pina.condition import InputTargetCondition
from pina.problem.zoo.supervised_problem import SupervisedProblem
from pina.graph import RadiusGraph
@@ -13,7 +13,7 @@ def test_constructor():
assert hasattr(problem, "conditions")
assert isinstance(problem.conditions, dict)
assert list(problem.conditions.keys()) == ["data"]
assert isinstance(problem.conditions["data"], InputOutputPointsCondition)
assert isinstance(problem.conditions["data"], InputTargetCondition)
def test_constructor_graph():
@@ -23,12 +23,12 @@ def test_constructor_graph():
RadiusGraph(x=x_, pos=pos_, radius=0.2, edge_attr=True)
for x_, pos_ in zip(x, pos)
]
output_ = torch.rand((100, 10))
output_ = torch.rand((20, 100, 10))
problem = SupervisedProblem(input_=input_, output_=output_)
assert isinstance(problem, AbstractProblem)
assert hasattr(problem, "conditions")
assert isinstance(problem.conditions, dict)
assert list(problem.conditions.keys()) == ["data"]
assert isinstance(problem.conditions["data"], InputOutputPointsCondition)
assert isinstance(problem.conditions["data"].input_points, list)
assert isinstance(problem.conditions["data"].output_points, torch.Tensor)
assert isinstance(problem.conditions["data"], InputTargetCondition)
assert isinstance(problem.conditions["data"].input, list)
assert isinstance(problem.conditions["data"].target, torch.Tensor)