Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -11,8 +11,8 @@ from pina.problem.zoo import (
|
||||
InverseDiffusionReactionProblem
|
||||
)
|
||||
from pina.condition import (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
InputTargetCondition,
|
||||
InputEquationCondition,
|
||||
DomainEquationCondition
|
||||
)
|
||||
from torch._dynamo.eval_frame import OptimizedModule
|
||||
@@ -43,8 +43,8 @@ input_pts = LabelTensor(input_pts, problem.input_variables)
|
||||
output_pts = torch.rand(50, len(problem.output_variables))
|
||||
output_pts = LabelTensor(output_pts, problem.output_variables)
|
||||
problem.conditions['data'] = Condition(
|
||||
input_points=input_pts,
|
||||
output_points=output_pts
|
||||
input=input_pts,
|
||||
target=output_pts
|
||||
)
|
||||
|
||||
|
||||
@@ -56,8 +56,8 @@ def test_constructor(problem, eps):
|
||||
solver = CausalPINN(model=model, problem=problem, eps=eps)
|
||||
|
||||
assert solver.accepted_conditions_types == (
|
||||
InputOutputPointsCondition,
|
||||
InputPointsEquationCondition,
|
||||
InputTargetCondition,
|
||||
InputEquationCondition,
|
||||
DomainEquationCondition
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user