Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
from pina import Condition, LabelTensor
|
||||
from pina.problem import AbstractProblem
|
||||
from pina.condition import InputOutputPointsCondition
|
||||
from pina.condition import InputTargetCondition
|
||||
from pina.solver import ReducedOrderModelSolver
|
||||
from pina.trainer import Trainer
|
||||
from pina.model import FeedForward
|
||||
@@ -16,8 +16,8 @@ class LabelTensorProblem(AbstractProblem):
|
||||
output_variables = ['u']
|
||||
conditions = {
|
||||
'data': Condition(
|
||||
input_points=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']),
|
||||
output_points=LabelTensor(torch.randn(20, 1), ['u'])),
|
||||
input=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']),
|
||||
target=LabelTensor(torch.randn(20, 1), ['u'])),
|
||||
}
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ class TensorProblem(AbstractProblem):
|
||||
output_variables = ['u']
|
||||
conditions = {
|
||||
'data': Condition(
|
||||
input_points=torch.randn(20, 2),
|
||||
output_points=torch.randn(20, 1))
|
||||
input=torch.randn(20, 2),
|
||||
target=torch.randn(20, 1))
|
||||
}
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ def test_constructor():
|
||||
ReducedOrderModelSolver(problem=LabelTensorProblem(),
|
||||
reduction_network=reduction_net,
|
||||
interpolation_network=interpolation_net)
|
||||
assert ReducedOrderModelSolver.accepted_conditions_types == InputOutputPointsCondition
|
||||
assert ReducedOrderModelSolver.accepted_conditions_types == InputTargetCondition
|
||||
with pytest.raises(SyntaxError):
|
||||
ReducedOrderModelSolver(problem=problem,
|
||||
reduction_network=AE_missing_encode(
|
||||
|
||||
Reference in New Issue
Block a user