Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -6,8 +6,8 @@ from pina.model import FeedForward
from pina.trainer import Trainer
from pina.solver import PINN
from pina.condition import (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition
)
from pina.problem.zoo import (
@@ -33,8 +33,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(50, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions['data'] = Condition(
input_points=input_pts,
output_points=output_pts
input=input_pts,
target=output_pts
)
@pytest.mark.parametrize("problem", [problem, inverse_problem])
@@ -42,8 +42,8 @@ def test_constructor(problem):
solver = PINN(problem=problem, model=model)
assert solver.accepted_conditions_types == (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition
)