Improve conditions and refactor dataset classes (#475)

* Reimplement conditions

* Refactor datasets and implement LabelBatch

---------

Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
Filippo Olivo
2025-03-07 11:24:09 +01:00
committed by Nicola Demo
parent bdad144461
commit a0cbf1c44a
40 changed files with 943 additions and 550 deletions

View File

@@ -10,8 +10,8 @@ from pina.problem.zoo import (
InversePoisson2DSquareProblem as InversePoisson
)
from pina.condition import (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition
)
from torch._dynamo.eval_frame import OptimizedModule
@@ -33,8 +33,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(50, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions['data'] = Condition(
input_points=input_pts,
output_points=output_pts
input=input_pts,
target=output_pts
)
@@ -46,8 +46,8 @@ def test_constructor(problem, weight_fn):
solver = SAPINN(problem=problem, model=model, weight_function=weight_fn)
assert solver.accepted_conditions_types == (
InputOutputPointsCondition,
InputPointsEquationCondition,
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition
)