Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -1,7 +1,7 @@
import torch
import pytest
from pina import Condition, LabelTensor
from pina.condition import InputOutputPointsCondition
from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem
from pina.solver import SupervisedSolver
from pina.model import FeedForward
@@ -14,8 +14,8 @@ class LabelTensorProblem(AbstractProblem):
output_variables = ['u']
conditions = {
'data': Condition(
input_points=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']),
output_points=LabelTensor(torch.randn(20, 1), ['u'])),
input=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']),
target=LabelTensor(torch.randn(20, 1), ['u'])),
}
@@ -24,8 +24,8 @@ class TensorProblem(AbstractProblem):
output_variables = ['u']
conditions = {
'data': Condition(
input_points=torch.randn(20, 2),
output_points=torch.randn(20, 1))
input=torch.randn(20, 2),
target=torch.randn(20, 1))
}
@@ -36,7 +36,7 @@ def test_constructor():
SupervisedSolver(problem=TensorProblem(), model=model)
SupervisedSolver(problem=LabelTensorProblem(), model=model)
assert SupervisedSolver.accepted_conditions_types == (
InputOutputPointsCondition
InputTargetCondition
)