Bug fix and add additional tests for Dataset and DataModule (#517)
This commit is contained in:
committed by
FilippoOlivo
parent
79a7199985
commit
80c257da4d
@@ -238,3 +238,94 @@ def test_dataloader_labels(input_, output_, automatic_batching):
|
||||
assert data["data"]["input"].labels == ["u", "v", "w"]
|
||||
assert isinstance(data["data"]["target"], torch.Tensor)
|
||||
assert data["data"]["target"].labels == ["u", "v", "w"]
|
||||
|
||||
|
||||
def test_get_all_data():
|
||||
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
|
||||
target = input
|
||||
|
||||
problem = SupervisedProblem(input, target)
|
||||
datamodule = PinaDataModule(
|
||||
problem,
|
||||
train_size=0.7,
|
||||
test_size=0.2,
|
||||
val_size=0.1,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
repeat=False,
|
||||
automatic_batching=None,
|
||||
num_workers=0,
|
||||
pin_memory=False,
|
||||
)
|
||||
datamodule.setup("fit")
|
||||
datamodule.setup("test")
|
||||
assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700
|
||||
assert torch.isclose(
|
||||
datamodule.train_dataset.get_all_data()["data"]["input"], input[:700]
|
||||
).all()
|
||||
assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100
|
||||
assert torch.isclose(
|
||||
datamodule.val_dataset.get_all_data()["data"]["input"], input[900:]
|
||||
).all()
|
||||
assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200
|
||||
assert torch.isclose(
|
||||
datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900]
|
||||
).all()
|
||||
|
||||
|
||||
def test_input_propery_tensor():
|
||||
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
|
||||
target = input
|
||||
|
||||
problem = SupervisedProblem(input, target)
|
||||
datamodule = PinaDataModule(
|
||||
problem,
|
||||
train_size=0.7,
|
||||
test_size=0.2,
|
||||
val_size=0.1,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
repeat=False,
|
||||
automatic_batching=None,
|
||||
num_workers=0,
|
||||
pin_memory=False,
|
||||
)
|
||||
datamodule.setup("fit")
|
||||
datamodule.setup("test")
|
||||
input_ = datamodule.input
|
||||
assert isinstance(input_, dict)
|
||||
assert isinstance(input_["train"], dict)
|
||||
assert isinstance(input_["val"], dict)
|
||||
assert isinstance(input_["test"], dict)
|
||||
assert torch.isclose(input_["train"]["data"], input[:700]).all()
|
||||
assert torch.isclose(input_["val"]["data"], input[900:]).all()
|
||||
assert torch.isclose(input_["test"]["data"], input[700:900]).all()
|
||||
|
||||
|
||||
def test_input_propery_graph():
|
||||
problem = SupervisedProblem(input_graph, output_graph)
|
||||
datamodule = PinaDataModule(
|
||||
problem,
|
||||
train_size=0.7,
|
||||
test_size=0.2,
|
||||
val_size=0.1,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
repeat=False,
|
||||
automatic_batching=None,
|
||||
num_workers=0,
|
||||
pin_memory=False,
|
||||
)
|
||||
datamodule.setup("fit")
|
||||
datamodule.setup("test")
|
||||
input_ = datamodule.input
|
||||
assert isinstance(input_, dict)
|
||||
assert isinstance(input_["train"], dict)
|
||||
assert isinstance(input_["val"], dict)
|
||||
assert isinstance(input_["test"], dict)
|
||||
assert isinstance(input_["train"]["data"], list)
|
||||
assert isinstance(input_["val"]["data"], list)
|
||||
assert isinstance(input_["test"]["data"], list)
|
||||
assert len(input_["train"]["data"]) == 70
|
||||
assert len(input_["val"]["data"]) == 10
|
||||
assert len(input_["test"]["data"]) == 20
|
||||
|
||||
Reference in New Issue
Block a user