Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -6,8 +6,8 @@ from pina.model import FeedForward
|
||||
from pina.trainer import Trainer
|
||||
from pina.solver import PINN
|
||||
from pina.condition import (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
InputTargetCondition,
|
||||
InputEquationCondition,
|
||||
DomainEquationCondition
|
||||
)
|
||||
from pina.problem.zoo import (
|
||||
@@ -33,8 +33,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables)
|
||||
output_pts = torch.rand(50, len(problem.output_variables))
|
||||
output_pts = LabelTensor(output_pts, problem.output_variables)
|
||||
problem.conditions['data'] = Condition(
|
||||
input_points=input_pts,
|
||||
output_points=output_pts
|
||||
input=input_pts,
|
||||
target=output_pts
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("problem", [problem, inverse_problem])
|
||||
@@ -42,8 +42,8 @@ def test_constructor(problem):
|
||||
solver = PINN(problem=problem, model=model)
|
||||
|
||||
assert solver.accepted_conditions_types == (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
InputTargetCondition,
|
||||
InputEquationCondition,
|
||||
DomainEquationCondition
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user