Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
import pytest
from pina import Condition, LabelTensor
from pina.solver import GAROM
from pina.condition import InputOutputPointsCondition
from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem
from pina.model import FeedForward
from pina.trainer import Trainer
@@ -16,8 +16,8 @@ class TensorProblem(AbstractProblem):
output_variables = ['u']
conditions = {
'data': Condition(
output_points=torch.randn(50, 2),
input_points=torch.randn(50, 1))
target=torch.randn(50, 2),
input=torch.randn(50, 1))
}
@@ -74,7 +74,7 @@ def test_constructor():
generator=Generator(),
discriminator=Discriminator())
assert GAROM.accepted_conditions_types == (
InputOutputPointsCondition
InputTargetCondition
)