Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -11,8 +11,8 @@ from pina.problem.zoo import (
InversePoisson2DSquareProblem as InversePoisson
)
from pina.condition import (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition
)
from torch._dynamo.eval_frame import OptimizedModule
@@ -43,8 +43,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(50, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions['data'] = Condition(
input_points=input_pts,
output_points=output_pts
input=input_pts,
target=output_pts
)
@@ -55,8 +55,8 @@ def test_constructor(problem):
solver = GradientPINN(model=model, problem=problem)
assert solver.accepted_conditions_types == (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition
)