Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -6,8 +6,8 @@ from pina.model import FeedForward
from pina.trainer import Trainer
from pina.solver import RBAPINN
from pina.condition import (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition
)
from pina.problem.zoo import (
@@ -32,8 +32,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(50, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions['data'] = Condition(
input_points=input_pts,
output_points=output_pts
input=input_pts,
target=output_pts
)
@@ -46,8 +46,8 @@ def test_constructor(problem, eta, gamma):
solver = RBAPINN(model=model, problem=problem, eta=eta, gamma=gamma)
assert solver.accepted_conditions_types == (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition
)