Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import pytest
|
||||
from pina import Condition, LabelTensor
|
||||
from pina.condition import InputOutputPointsCondition
|
||||
from pina.condition import InputTargetCondition
|
||||
from pina.problem import AbstractProblem
|
||||
from pina.solver import SupervisedSolver
|
||||
from pina.model import FeedForward
|
||||
@@ -14,8 +14,8 @@ class LabelTensorProblem(AbstractProblem):
|
||||
output_variables = ['u']
|
||||
conditions = {
|
||||
'data': Condition(
|
||||
input_points=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']),
|
||||
output_points=LabelTensor(torch.randn(20, 1), ['u'])),
|
||||
input=LabelTensor(torch.randn(20, 2), ['u_0', 'u_1']),
|
||||
target=LabelTensor(torch.randn(20, 1), ['u'])),
|
||||
}
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ class TensorProblem(AbstractProblem):
|
||||
output_variables = ['u']
|
||||
conditions = {
|
||||
'data': Condition(
|
||||
input_points=torch.randn(20, 2),
|
||||
output_points=torch.randn(20, 1))
|
||||
input=torch.randn(20, 2),
|
||||
target=torch.randn(20, 1))
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ def test_constructor():
|
||||
SupervisedSolver(problem=TensorProblem(), model=model)
|
||||
SupervisedSolver(problem=LabelTensorProblem(), model=model)
|
||||
assert SupervisedSolver.accepted_conditions_types == (
|
||||
InputOutputPointsCondition
|
||||
InputTargetCondition
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user