Bug fix and add additional tests for Dataset and DataModule (#517)

Filippo Olivo
2025-03-25 12:18:27 +01:00
committed by GitHub
parent 03ef90c358
commit ef29f0a95d
4 changed files with 143 additions and 16 deletions

View File

@@ -238,3 +238,94 @@ def test_dataloader_labels(input_, output_, automatic_batching):
    assert data["data"]["input"].labels == ["u", "v", "w"]
    assert isinstance(data["data"]["target"], torch.Tensor)
    assert data["data"]["target"].labels == ["u", "v", "w"]


def test_get_all_data():
    input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
    target = input
    problem = SupervisedProblem(input, target)
    datamodule = PinaDataModule(
        problem,
        train_size=0.7,
        test_size=0.2,
        val_size=0.1,
        batch_size=64,
        shuffle=False,
        repeat=False,
        automatic_batching=None,
        num_workers=0,
        pin_memory=False,
    )
    datamodule.setup("fit")
    datamodule.setup("test")
    assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700
    assert torch.isclose(
        datamodule.train_dataset.get_all_data()["data"]["input"], input[:700]
    ).all()
    assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100
    assert torch.isclose(
        datamodule.val_dataset.get_all_data()["data"]["input"], input[900:]
    ).all()
    assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200
    assert torch.isclose(
        datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900]
    ).all()


def test_input_property_tensor():
    input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
    target = input
    problem = SupervisedProblem(input, target)
    datamodule = PinaDataModule(
        problem,
        train_size=0.7,
        test_size=0.2,
        val_size=0.1,
        batch_size=64,
        shuffle=False,
        repeat=False,
        automatic_batching=None,
        num_workers=0,
        pin_memory=False,
    )
    datamodule.setup("fit")
    datamodule.setup("test")
    input_ = datamodule.input
    assert isinstance(input_, dict)
    assert isinstance(input_["train"], dict)
    assert isinstance(input_["val"], dict)
    assert isinstance(input_["test"], dict)
    assert torch.isclose(input_["train"]["data"], input[:700]).all()
    assert torch.isclose(input_["val"]["data"], input[900:]).all()
    assert torch.isclose(input_["test"]["data"], input[700:900]).all()


def test_input_property_graph():
    problem = SupervisedProblem(input_graph, output_graph)
    datamodule = PinaDataModule(
        problem,
        train_size=0.7,
        test_size=0.2,
        val_size=0.1,
        batch_size=64,
        shuffle=False,
        repeat=False,
        automatic_batching=None,
        num_workers=0,
        pin_memory=False,
    )
    datamodule.setup("fit")
    datamodule.setup("test")
    input_ = datamodule.input
    assert isinstance(input_, dict)
    assert isinstance(input_["train"], dict)
    assert isinstance(input_["val"], dict)
    assert isinstance(input_["test"], dict)
    assert isinstance(input_["train"]["data"], list)
    assert isinstance(input_["val"]["data"], list)
    assert isinstance(input_["test"]["data"], list)
    assert len(input_["train"]["data"]) == 70
    assert len(input_["val"]["data"]) == 10
    assert len(input_["test"]["data"]) == 20
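
A note on the slicing in these tests: the asserts imply that, with shuffle=False, PinaDataModule assigns samples to the splits sequentially, first train_size, then test_size, then val_size, which is why the validation set matches input[900:] and the test set input[700:900]. Below is a minimal, self-contained sketch of that assumed ordering; it uses plain torch only, and the helper name sequential_split is ours, not part of PINA.

import torch

def sequential_split(data, train_size, test_size):
    # Hypothetical helper mirroring the ordering assumed by the asserts
    # above: the first train_size fraction goes to train, the next
    # test_size fraction to test, and the remainder to validation.
    # This matches PinaDataModule only when shuffle=False.
    n = len(data)
    n_train = int(n * train_size)
    n_test = int(n * test_size)
    return (
        data[:n_train],                    # train -> input[:700]
        data[n_train:n_train + n_test],    # test  -> input[700:900]
        data[n_train + n_test:],           # val   -> input[900:]
    )

data = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
train, test, val = sequential_split(data, train_size=0.7, test_size=0.2)
assert len(train) == 700 and len(test) == 200 and len(val) == 100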

View File

@@ -31,7 +31,7 @@ conditions_dict_single = {
max_conditions_lengths_single = {"data": 100}
# Problem with multiple conditions
conditions_dict_single_multi = {
conditions_dict_multi = {
"data_1": {
"input": input_,
"target": output_,
@@ -49,7 +49,7 @@ max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
"conditions_dict, max_conditions_lengths",
[
(conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_single_multi, max_conditions_lengths_multi),
(conditions_dict_multi, max_conditions_lengths_multi),
],
)
def test_constructor(conditions_dict, max_conditions_lengths):
@@ -66,7 +66,7 @@ def test_constructor(conditions_dict, max_conditions_lengths):
"conditions_dict, max_conditions_lengths",
[
(conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_single_multi, max_conditions_lengths_multi),
(conditions_dict_multi, max_conditions_lengths_multi),
],
)
def test_getitem(conditions_dict, max_conditions_lengths):
@@ -110,3 +110,29 @@ def test_getitem(conditions_dict, max_conditions_lengths):
        ]
    )
    assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])


def test_input_single_condition():
    dataset = PinaDatasetFactory(
        conditions_dict_single,
        max_conditions_lengths=max_conditions_lengths_single,
        automatic_batching=True,
    )
    input_ = dataset.input
    assert isinstance(input_, dict)
    assert isinstance(input_["data"], list)
    assert all([isinstance(d, Data) for d in input_["data"]])


def test_input_multi_condition():
    dataset = PinaDatasetFactory(
        conditions_dict_multi,
        max_conditions_lengths=max_conditions_lengths_multi,
        automatic_batching=True,
    )
    input_ = dataset.input
    assert isinstance(input_, dict)
    assert isinstance(input_["data_1"], list)
    assert all([isinstance(d, Data) for d in input_["data_1"]])
    assert isinstance(input_["data_2"], list)
    assert all([isinstance(d, Data) for d in input_["data_2"]])
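
For reference, these two tests pin down the contract of the dataset's input property for graph conditions: a dict keyed by condition name whose values are lists of torch_geometric Data objects. The small helper below is a hypothetical restatement of that contract (the name assert_graph_input_shape is ours, not part of PINA):

from torch_geometric.data import Data

def assert_graph_input_shape(input_):
    # Mirrors the assertions in the two tests above: one entry per
    # condition name, each entry a list of Data objects.
    assert isinstance(input_, dict)
    for condition_name, samples in input_.items():
        assert isinstance(samples, list), condition_name
        assert all(isinstance(d, Data) for d in samples), condition_name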