Bug fix and add additional tests for Dataset and DataModule (#517)
This commit is contained in:
committed by
FilippoOlivo
parent
79a7199985
commit
80c257da4d
@@ -31,7 +31,7 @@ conditions_dict_single = {
|
||||
max_conditions_lengths_single = {"data": 100}
|
||||
|
||||
# Problem with multiple conditions
|
||||
conditions_dict_single_multi = {
|
||||
conditions_dict_multi = {
|
||||
"data_1": {
|
||||
"input": input_,
|
||||
"target": output_,
|
||||
@@ -49,7 +49,7 @@ max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
|
||||
"conditions_dict, max_conditions_lengths",
|
||||
[
|
||||
(conditions_dict_single, max_conditions_lengths_single),
|
||||
(conditions_dict_single_multi, max_conditions_lengths_multi),
|
||||
(conditions_dict_multi, max_conditions_lengths_multi),
|
||||
],
|
||||
)
|
||||
def test_constructor(conditions_dict, max_conditions_lengths):
|
||||
@@ -66,7 +66,7 @@ def test_constructor(conditions_dict, max_conditions_lengths):
|
||||
"conditions_dict, max_conditions_lengths",
|
||||
[
|
||||
(conditions_dict_single, max_conditions_lengths_single),
|
||||
(conditions_dict_single_multi, max_conditions_lengths_multi),
|
||||
(conditions_dict_multi, max_conditions_lengths_multi),
|
||||
],
|
||||
)
|
||||
def test_getitem(conditions_dict, max_conditions_lengths):
|
||||
@@ -110,3 +110,29 @@ def test_getitem(conditions_dict, max_conditions_lengths):
|
||||
]
|
||||
)
|
||||
assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])
|
||||
|
||||
|
||||
def test_input_single_condition():
|
||||
dataset = PinaDatasetFactory(
|
||||
conditions_dict_single,
|
||||
max_conditions_lengths=max_conditions_lengths_single,
|
||||
automatic_batching=True,
|
||||
)
|
||||
input_ = dataset.input
|
||||
assert isinstance(input_, dict)
|
||||
assert isinstance(input_["data"], list)
|
||||
assert all([isinstance(d, Data) for d in input_["data"]])
|
||||
|
||||
|
||||
def test_input_multi_condition():
|
||||
dataset = PinaDatasetFactory(
|
||||
conditions_dict_multi,
|
||||
max_conditions_lengths=max_conditions_lengths_multi,
|
||||
automatic_batching=True,
|
||||
)
|
||||
input_ = dataset.input
|
||||
assert isinstance(input_, dict)
|
||||
assert isinstance(input_["data_1"], list)
|
||||
assert all([isinstance(d, Data) for d in input_["data_1"]])
|
||||
assert isinstance(input_["data_2"], list)
|
||||
assert all([isinstance(d, Data) for d in input_["data_2"]])
|
||||
|
||||
Reference in New Issue
Block a user