Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <dariocos99@gmail.com>
This commit is contained in:
committed by
Nicola Demo
parent
bdad144461
commit
a0cbf1c44a
@@ -4,7 +4,7 @@ import torch.nn as nn
|
||||
import pytest
|
||||
from pina import Condition, LabelTensor
|
||||
from pina.solver import GAROM
|
||||
from pina.condition import InputOutputPointsCondition
|
||||
from pina.condition import InputTargetCondition
|
||||
from pina.problem import AbstractProblem
|
||||
from pina.model import FeedForward
|
||||
from pina.trainer import Trainer
|
||||
@@ -16,8 +16,8 @@ class TensorProblem(AbstractProblem):
|
||||
output_variables = ['u']
|
||||
conditions = {
|
||||
'data': Condition(
|
||||
output_points=torch.randn(50, 2),
|
||||
input_points=torch.randn(50, 1))
|
||||
target=torch.randn(50, 2),
|
||||
input=torch.randn(50, 1))
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ def test_constructor():
|
||||
generator=Generator(),
|
||||
discriminator=Discriminator())
|
||||
assert GAROM.accepted_conditions_types == (
|
||||
InputOutputPointsCondition
|
||||
InputTargetCondition
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user