fix tests

This commit is contained in:
FilippoOlivo
2025-11-13 17:03:31 +01:00
parent 0ee63686dd
commit 8440a672a7
5 changed files with 289 additions and 300 deletions

View File

@@ -51,7 +51,7 @@ def test_sample(condition_to_update):
} }
trainer.train() trainer.train()
after_n_points = { after_n_points = {
loc: len(trainer.data_module.train_dataset.input[loc]) loc: len(trainer.data_module.train_dataset[loc].input)
for loc in condition_to_update for loc in condition_to_update
} }
assert before_n_points == trainer.callbacks[0].initial_population_size assert before_n_points == trainer.callbacks[0].initial_population_size

View File

@@ -142,14 +142,10 @@ def test_setup(solver, fn, stage, apply_to):
for cond in ["data1", "data2"]: for cond in ["data1", "data2"]:
scale = scale_fn( scale = scale_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][ trainer_copy.data_module.train_dataset[cond].data[apply_to]
apply_to
]
) )
shift = shift_fn( shift = shift_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][ trainer_copy.data_module.train_dataset[cond].data[apply_to]
apply_to
]
) )
assert "scale" in normalizer[cond] assert "scale" in normalizer[cond]
assert "shift" in normalizer[cond] assert "shift" in normalizer[cond]
@@ -158,8 +154,8 @@ def test_setup(solver, fn, stage, apply_to):
for ds_name in stage_map[stage]: for ds_name in stage_map[stage]:
dataset = getattr(trainer.data_module, ds_name, None) dataset = getattr(trainer.data_module, ds_name, None)
old_dataset = getattr(trainer_copy.data_module, ds_name, None) old_dataset = getattr(trainer_copy.data_module, ds_name, None)
current_points = dataset.conditions_dict[cond][apply_to] current_points = dataset[cond].data[apply_to]
old_points = old_dataset.conditions_dict[cond][apply_to] old_points = old_dataset[cond].data[apply_to]
expected = (old_points - shift) / scale expected = (old_points - shift) / scale
assert torch.allclose(current_points, expected) assert torch.allclose(current_points, expected)
@@ -204,10 +200,10 @@ def test_setup_pinn(fn, stage, apply_to):
cond = "data" cond = "data"
scale = scale_fn( scale = scale_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to] trainer_copy.data_module.train_dataset[cond].data[apply_to]
) )
shift = shift_fn( shift = shift_fn(
trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to] trainer_copy.data_module.train_dataset[cond].data[apply_to]
) )
assert "scale" in normalizer[cond] assert "scale" in normalizer[cond]
assert "shift" in normalizer[cond] assert "shift" in normalizer[cond]
@@ -216,8 +212,8 @@ def test_setup_pinn(fn, stage, apply_to):
for ds_name in stage_map[stage]: for ds_name in stage_map[stage]:
dataset = getattr(trainer.data_module, ds_name, None) dataset = getattr(trainer.data_module, ds_name, None)
old_dataset = getattr(trainer_copy.data_module, ds_name, None) old_dataset = getattr(trainer_copy.data_module, ds_name, None)
current_points = dataset.conditions_dict[cond][apply_to] current_points = dataset[cond].data[apply_to]
old_points = old_dataset.conditions_dict[cond][apply_to] old_points = old_dataset[cond].data[apply_to]
expected = (old_points - shift) / scale expected = (old_points - shift) / scale
assert torch.allclose(current_points, expected) assert torch.allclose(current_points, expected)
@@ -242,3 +238,7 @@ def test_setup_graph_dataset():
) )
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
trainer.train() trainer.train()
# if __name__ == "__main__":
# test_setup(supervised_solver_lt, [torch.std, torch.mean], "all", "input")

View File

@@ -1,10 +1,11 @@
import torch import torch
import pytest import pytest
from pina.data import PinaDataModule from pina.data import PinaDataModule
from pina.data.dataset import PinaTensorDataset, PinaGraphDataset from pina.data.dataset import PinaDataset
from pina.problem.zoo import SupervisedProblem from pina.problem.zoo import SupervisedProblem
from pina.graph import RadiusGraph from pina.graph import RadiusGraph
from pina.data.data_module import DummyDataloader
from pina.data.dataloader import DummyDataloader, PinaDataLoader
from pina import Trainer from pina import Trainer
from pina.solver import SupervisedSolver from pina.solver import SupervisedSolver
from torch_geometric.data import Batch from torch_geometric.data import Batch
@@ -44,22 +45,33 @@ def test_setup_train(input_, output_, train_size, val_size, test_size):
) )
dm.setup() dm.setup()
assert hasattr(dm, "train_dataset") assert hasattr(dm, "train_dataset")
if isinstance(input_, torch.Tensor): assert isinstance(dm.train_dataset, dict)
assert isinstance(dm.train_dataset, PinaTensorDataset) assert all(
else: isinstance(dm.train_dataset[cond], PinaDataset)
assert isinstance(dm.train_dataset, PinaGraphDataset) for cond in dm.train_dataset
# assert len(dm.train_dataset) == int(len(input_) * train_size) )
assert all(
dm.train_dataset[cond].is_graph_dataset == isinstance(input_, list)
for cond in dm.train_dataset
)
assert all(
len(dm.train_dataset[cond]) == int(len(input_) * train_size)
for cond in dm.train_dataset
)
if test_size > 0: if test_size > 0:
assert hasattr(dm, "test_dataset") assert hasattr(dm, "test_dataset")
assert dm.test_dataset is None assert dm.test_dataset is None
else: else:
assert not hasattr(dm, "test_dataset") assert not hasattr(dm, "test_dataset")
assert hasattr(dm, "val_dataset") assert hasattr(dm, "val_dataset")
if isinstance(input_, torch.Tensor):
assert isinstance(dm.val_dataset, PinaTensorDataset) assert isinstance(dm.val_dataset, dict)
else: assert all(
assert isinstance(dm.val_dataset, PinaGraphDataset) isinstance(dm.val_dataset[cond], PinaDataset) for cond in dm.val_dataset
# assert len(dm.val_dataset) == int(len(input_) * val_size) )
assert all(
isinstance(dm.val_dataset[cond], PinaDataset) for cond in dm.val_dataset
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -87,49 +99,59 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
assert not hasattr(dm, "val_dataset") assert not hasattr(dm, "val_dataset")
assert hasattr(dm, "test_dataset") assert hasattr(dm, "test_dataset")
if isinstance(input_, torch.Tensor): assert all(
assert isinstance(dm.test_dataset, PinaTensorDataset) isinstance(dm.test_dataset[cond], PinaDataset)
else: for cond in dm.test_dataset
assert isinstance(dm.test_dataset, PinaGraphDataset) )
# assert len(dm.test_dataset) == int(len(input_) * test_size) assert all(
dm.test_dataset[cond].is_graph_dataset == isinstance(input_, list)
for cond in dm.test_dataset
@pytest.mark.parametrize( )
"input_, output_", assert all(
[(input_tensor, output_tensor), (input_graph, output_graph)], len(dm.test_dataset[cond]) == int(len(input_) * test_size)
) for cond in dm.test_dataset
def test_dummy_dataloader(input_, output_):
problem = SupervisedProblem(input_=input_, output_=output_)
solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
trainer = Trainer(
solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0
) )
dm = trainer.data_module
dm.setup()
dm.trainer = trainer
dataloader = dm.train_dataloader()
assert isinstance(dataloader, DummyDataloader)
assert len(dataloader) == 1
data = next(dataloader)
assert isinstance(data, list)
assert isinstance(data[0], tuple)
if isinstance(input_, list):
assert isinstance(data[0][1]["input"], Batch)
else:
assert isinstance(data[0][1]["input"], torch.Tensor)
assert isinstance(data[0][1]["target"], torch.Tensor)
dataloader = dm.val_dataloader()
assert isinstance(dataloader, DummyDataloader) # @pytest.mark.parametrize(
assert len(dataloader) == 1 # "input_, output_",
data = next(dataloader) # [(input_tensor, output_tensor), (input_graph, output_graph)],
assert isinstance(data, list) # )
assert isinstance(data[0], tuple) # def test_dummy_dataloader(input_, output_):
if isinstance(input_, list): # problem = SupervisedProblem(input_=input_, output_=output_)
assert isinstance(data[0][1]["input"], Batch) # solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
else: # trainer = Trainer(
assert isinstance(data[0][1]["input"], torch.Tensor) # solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0
assert isinstance(data[0][1]["target"], torch.Tensor) # )
# dm = trainer.data_module
# dm.setup()
# dm.trainer = trainer
# dataloader = dm.train_dataloader()
# assert isinstance(dataloader, PinaDataLoader)
# print(dataloader.dataloaders)
# assert all([isinstance(ds, DummyDataloader) for ds in dataloader.dataloaders.values()])
# data = next(iter(dataloader))
# assert isinstance(data, list)
# assert isinstance(data[0], tuple)
# if isinstance(input_, list):
# assert isinstance(data[0][1]["input"], Batch)
# else:
# assert isinstance(data[0][1]["input"], torch.Tensor)
# assert isinstance(data[0][1]["target"], torch.Tensor)
# dataloader = dm.val_dataloader()
# assert isinstance(dataloader, DummyDataloader)
# assert len(dataloader) == 1
# data = next(dataloader)
# assert isinstance(data, list)
# assert isinstance(data[0], tuple)
# if isinstance(input_, list):
# assert isinstance(data[0][1]["input"], Batch)
# else:
# assert isinstance(data[0][1]["input"], torch.Tensor)
# assert isinstance(data[0][1]["target"], torch.Tensor)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -147,12 +169,13 @@ def test_dataloader(input_, output_, automatic_batching):
val_size=0.3, val_size=0.3,
test_size=0.0, test_size=0.0,
automatic_batching=automatic_batching, automatic_batching=automatic_batching,
common_batch_size=True,
) )
dm = trainer.data_module dm = trainer.data_module
dm.setup() dm.setup()
dm.trainer = trainer dm.trainer = trainer
dataloader = dm.train_dataloader() dataloader = dm.train_dataloader()
assert isinstance(dataloader, DataLoader) assert isinstance(dataloader, PinaDataLoader)
assert len(dataloader) == 7 assert len(dataloader) == 7
data = next(iter(dataloader)) data = next(iter(dataloader))
assert isinstance(data, dict) assert isinstance(data, dict)
@@ -163,7 +186,7 @@ def test_dataloader(input_, output_, automatic_batching):
assert isinstance(data["data"]["target"], torch.Tensor) assert isinstance(data["data"]["target"], torch.Tensor)
dataloader = dm.val_dataloader() dataloader = dm.val_dataloader()
assert isinstance(dataloader, DataLoader) assert isinstance(dataloader, PinaDataLoader)
assert len(dataloader) == 3 assert len(dataloader) == 3
data = next(iter(dataloader)) data = next(iter(dataloader))
assert isinstance(data, dict) assert isinstance(data, dict)
@@ -202,12 +225,13 @@ def test_dataloader_labels(input_, output_, automatic_batching):
val_size=0.3, val_size=0.3,
test_size=0.0, test_size=0.0,
automatic_batching=automatic_batching, automatic_batching=automatic_batching,
common_batch_size=True,
) )
dm = trainer.data_module dm = trainer.data_module
dm.setup() dm.setup()
dm.trainer = trainer dm.trainer = trainer
dataloader = dm.train_dataloader() dataloader = dm.train_dataloader()
assert isinstance(dataloader, DataLoader) assert isinstance(dataloader, PinaDataLoader)
assert len(dataloader) == 7 assert len(dataloader) == 7
data = next(iter(dataloader)) data = next(iter(dataloader))
assert isinstance(data, dict) assert isinstance(data, dict)
@@ -223,7 +247,7 @@ def test_dataloader_labels(input_, output_, automatic_batching):
assert data["data"]["target"].labels == ["u", "v", "w"] assert data["data"]["target"].labels == ["u", "v", "w"]
dataloader = dm.val_dataloader() dataloader = dm.val_dataloader()
assert isinstance(dataloader, DataLoader) assert isinstance(dataloader, PinaDataLoader)
assert len(dataloader) == 3 assert len(dataloader) == 3
data = next(iter(dataloader)) data = next(iter(dataloader))
assert isinstance(data, dict) assert isinstance(data, dict)
@@ -240,39 +264,6 @@ def test_dataloader_labels(input_, output_, automatic_batching):
assert data["data"]["target"].labels == ["u", "v", "w"] assert data["data"]["target"].labels == ["u", "v", "w"]
def test_get_all_data():
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
target = input
problem = SupervisedProblem(input, target)
datamodule = PinaDataModule(
problem,
train_size=0.7,
test_size=0.2,
val_size=0.1,
batch_size=64,
shuffle=False,
repeat=False,
automatic_batching=None,
num_workers=0,
pin_memory=False,
)
datamodule.setup("fit")
datamodule.setup("test")
assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700
assert torch.isclose(
datamodule.train_dataset.get_all_data()["data"]["input"], input[:700]
).all()
assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100
assert torch.isclose(
datamodule.val_dataset.get_all_data()["data"]["input"], input[900:]
).all()
assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200
assert torch.isclose(
datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900]
).all()
def test_input_propery_tensor(): def test_input_propery_tensor():
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)]) input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
target = input target = input
@@ -285,7 +276,6 @@ def test_input_propery_tensor():
val_size=0.1, val_size=0.1,
batch_size=64, batch_size=64,
shuffle=False, shuffle=False,
repeat=False,
automatic_batching=None, automatic_batching=None,
num_workers=0, num_workers=0,
pin_memory=False, pin_memory=False,
@@ -311,7 +301,6 @@ def test_input_propery_graph():
val_size=0.1, val_size=0.1,
batch_size=64, batch_size=64,
shuffle=False, shuffle=False,
repeat=False,
automatic_batching=None, automatic_batching=None,
num_workers=0, num_workers=0,
pin_memory=False, pin_memory=False,

View File

@@ -1,138 +1,138 @@
import torch # import torch
import pytest # import pytest
from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset # from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset
from pina.graph import KNNGraph # from pina.graph import KNNGraph
from torch_geometric.data import Data # from torch_geometric.data import Data
x = torch.rand((100, 20, 10)) # x = torch.rand((100, 20, 10))
pos = torch.rand((100, 20, 2)) # pos = torch.rand((100, 20, 2))
input_ = [ # input_ = [
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) # KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
for x_, pos_ in zip(x, pos) # for x_, pos_ in zip(x, pos)
] # ]
output_ = torch.rand((100, 20, 10)) # output_ = torch.rand((100, 20, 10))
x_2 = torch.rand((50, 20, 10)) # x_2 = torch.rand((50, 20, 10))
pos_2 = torch.rand((50, 20, 2)) # pos_2 = torch.rand((50, 20, 2))
input_2_ = [ # input_2_ = [
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) # KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
for x_, pos_ in zip(x_2, pos_2) # for x_, pos_ in zip(x_2, pos_2)
] # ]
output_2_ = torch.rand((50, 20, 10)) # output_2_ = torch.rand((50, 20, 10))
# Problem with a single condition # # Problem with a single condition
conditions_dict_single = { # conditions_dict_single = {
"data": { # "data": {
"input": input_, # "input": input_,
"target": output_, # "target": output_,
} # }
} # }
max_conditions_lengths_single = {"data": 100} # max_conditions_lengths_single = {"data": 100}
# Problem with multiple conditions # # Problem with multiple conditions
conditions_dict_multi = { # conditions_dict_multi = {
"data_1": { # "data_1": {
"input": input_, # "input": input_,
"target": output_, # "target": output_,
}, # },
"data_2": { # "data_2": {
"input": input_2_, # "input": input_2_,
"target": output_2_, # "target": output_2_,
}, # },
} # }
max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} # max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"conditions_dict, max_conditions_lengths", # "conditions_dict, max_conditions_lengths",
[ # [
(conditions_dict_single, max_conditions_lengths_single), # (conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_multi, max_conditions_lengths_multi), # (conditions_dict_multi, max_conditions_lengths_multi),
], # ],
) # )
def test_constructor(conditions_dict, max_conditions_lengths): # def test_constructor(conditions_dict, max_conditions_lengths):
dataset = PinaDatasetFactory( # dataset = PinaDatasetFactory(
conditions_dict, # conditions_dict,
max_conditions_lengths=max_conditions_lengths, # max_conditions_lengths=max_conditions_lengths,
automatic_batching=True, # automatic_batching=True,
) # )
assert isinstance(dataset, PinaGraphDataset) # assert isinstance(dataset, PinaGraphDataset)
assert len(dataset) == 100 # assert len(dataset) == 100
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"conditions_dict, max_conditions_lengths", # "conditions_dict, max_conditions_lengths",
[ # [
(conditions_dict_single, max_conditions_lengths_single), # (conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_multi, max_conditions_lengths_multi), # (conditions_dict_multi, max_conditions_lengths_multi),
], # ],
) # )
def test_getitem(conditions_dict, max_conditions_lengths): # def test_getitem(conditions_dict, max_conditions_lengths):
dataset = PinaDatasetFactory( # dataset = PinaDatasetFactory(
conditions_dict, # conditions_dict,
max_conditions_lengths=max_conditions_lengths, # max_conditions_lengths=max_conditions_lengths,
automatic_batching=True, # automatic_batching=True,
) # )
data = dataset[50] # data = dataset[50]
assert isinstance(data, dict) # assert isinstance(data, dict)
assert all([isinstance(d["input"], Data) for d in data.values()]) # assert all([isinstance(d["input"], Data) for d in data.values()])
assert all([isinstance(d["target"], torch.Tensor) for d in data.values()]) # assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
assert all( # assert all(
[d["input"].x.shape == torch.Size((20, 10)) for d in data.values()] # [d["input"].x.shape == torch.Size((20, 10)) for d in data.values()]
) # )
assert all( # assert all(
[d["target"].shape == torch.Size((20, 10)) for d in data.values()] # [d["target"].shape == torch.Size((20, 10)) for d in data.values()]
) # )
assert all( # assert all(
[ # [
d["input"].edge_index.shape == torch.Size((2, 60)) # d["input"].edge_index.shape == torch.Size((2, 60))
for d in data.values() # for d in data.values()
] # ]
) # )
assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()]) # assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()])
data = dataset.fetch_from_idx_list([i for i in range(20)]) # data = dataset.fetch_from_idx_list([i for i in range(20)])
assert isinstance(data, dict) # assert isinstance(data, dict)
assert all([isinstance(d["input"], Data) for d in data.values()]) # assert all([isinstance(d["input"], Data) for d in data.values()])
assert all([isinstance(d["target"], torch.Tensor) for d in data.values()]) # assert all([isinstance(d["target"], torch.Tensor) for d in data.values()])
assert all( # assert all(
[d["input"].x.shape == torch.Size((400, 10)) for d in data.values()] # [d["input"].x.shape == torch.Size((400, 10)) for d in data.values()]
) # )
assert all( # assert all(
[d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()] # [d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()]
) # )
assert all( # assert all(
[ # [
d["input"].edge_index.shape == torch.Size((2, 1200)) # d["input"].edge_index.shape == torch.Size((2, 1200))
for d in data.values() # for d in data.values()
] # ]
) # )
assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()]) # assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])
def test_input_single_condition(): # def test_input_single_condition():
dataset = PinaDatasetFactory( # dataset = PinaDatasetFactory(
conditions_dict_single, # conditions_dict_single,
max_conditions_lengths=max_conditions_lengths_single, # max_conditions_lengths=max_conditions_lengths_single,
automatic_batching=True, # automatic_batching=True,
) # )
input_ = dataset.input # input_ = dataset.input
assert isinstance(input_, dict) # assert isinstance(input_, dict)
assert isinstance(input_["data"], list) # assert isinstance(input_["data"], list)
assert all([isinstance(d, Data) for d in input_["data"]]) # assert all([isinstance(d, Data) for d in input_["data"]])
def test_input_multi_condition(): # def test_input_multi_condition():
dataset = PinaDatasetFactory( # dataset = PinaDatasetFactory(
conditions_dict_multi, # conditions_dict_multi,
max_conditions_lengths=max_conditions_lengths_multi, # max_conditions_lengths=max_conditions_lengths_multi,
automatic_batching=True, # automatic_batching=True,
) # )
input_ = dataset.input # input_ = dataset.input
assert isinstance(input_, dict) # assert isinstance(input_, dict)
assert isinstance(input_["data_1"], list) # assert isinstance(input_["data_1"], list)
assert all([isinstance(d, Data) for d in input_["data_1"]]) # assert all([isinstance(d, Data) for d in input_["data_1"]])
assert isinstance(input_["data_2"], list) # assert isinstance(input_["data_2"], list)
assert all([isinstance(d, Data) for d in input_["data_2"]]) # assert all([isinstance(d, Data) for d in input_["data_2"]])

View File

@@ -1,86 +1,86 @@
import torch # import torch
import pytest # import pytest
from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset # from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset
input_tensor = torch.rand((100, 10)) # input_tensor = torch.rand((100, 10))
output_tensor = torch.rand((100, 2)) # output_tensor = torch.rand((100, 2))
input_tensor_2 = torch.rand((50, 10)) # input_tensor_2 = torch.rand((50, 10))
output_tensor_2 = torch.rand((50, 2)) # output_tensor_2 = torch.rand((50, 2))
conditions_dict_single = { # conditions_dict_single = {
"data": { # "data": {
"input": input_tensor, # "input": input_tensor,
"target": output_tensor, # "target": output_tensor,
} # }
} # }
conditions_dict_single_multi = { # conditions_dict_single_multi = {
"data_1": { # "data_1": {
"input": input_tensor, # "input": input_tensor,
"target": output_tensor, # "target": output_tensor,
}, # },
"data_2": { # "data_2": {
"input": input_tensor_2, # "input": input_tensor_2,
"target": output_tensor_2, # "target": output_tensor_2,
}, # },
} # }
max_conditions_lengths_single = {"data": 100} # max_conditions_lengths_single = {"data": 100}
max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} # max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"conditions_dict, max_conditions_lengths", # "conditions_dict, max_conditions_lengths",
[ # [
(conditions_dict_single, max_conditions_lengths_single), # (conditions_dict_single, max_conditions_lengths_single),
(conditions_dict_single_multi, max_conditions_lengths_multi), # (conditions_dict_single_multi, max_conditions_lengths_multi),
], # ],
) # )
def test_constructor_tensor(conditions_dict, max_conditions_lengths): # def test_constructor_tensor(conditions_dict, max_conditions_lengths):
dataset = PinaDatasetFactory( # dataset = PinaDatasetFactory(
conditions_dict, # conditions_dict,
max_conditions_lengths=max_conditions_lengths, # max_conditions_lengths=max_conditions_lengths,
automatic_batching=True, # automatic_batching=True,
) # )
assert isinstance(dataset, PinaTensorDataset) # assert isinstance(dataset, PinaTensorDataset)
def test_getitem_single(): # def test_getitem_single():
dataset = PinaDatasetFactory( # dataset = PinaDatasetFactory(
conditions_dict_single, # conditions_dict_single,
max_conditions_lengths=max_conditions_lengths_single, # max_conditions_lengths=max_conditions_lengths_single,
automatic_batching=False, # automatic_batching=False,
) # )
tensors = dataset.fetch_from_idx_list([i for i in range(70)]) # tensors = dataset.fetch_from_idx_list([i for i in range(70)])
assert isinstance(tensors, dict) # assert isinstance(tensors, dict)
assert list(tensors.keys()) == ["data"] # assert list(tensors.keys()) == ["data"]
assert sorted(list(tensors["data"].keys())) == ["input", "target"] # assert sorted(list(tensors["data"].keys())) == ["input", "target"]
assert isinstance(tensors["data"]["input"], torch.Tensor) # assert isinstance(tensors["data"]["input"], torch.Tensor)
assert tensors["data"]["input"].shape == torch.Size((70, 10)) # assert tensors["data"]["input"].shape == torch.Size((70, 10))
assert isinstance(tensors["data"]["target"], torch.Tensor) # assert isinstance(tensors["data"]["target"], torch.Tensor)
assert tensors["data"]["target"].shape == torch.Size((70, 2)) # assert tensors["data"]["target"].shape == torch.Size((70, 2))
def test_getitem_multi(): # def test_getitem_multi():
dataset = PinaDatasetFactory( # dataset = PinaDatasetFactory(
conditions_dict_single_multi, # conditions_dict_single_multi,
max_conditions_lengths=max_conditions_lengths_multi, # max_conditions_lengths=max_conditions_lengths_multi,
automatic_batching=False, # automatic_batching=False,
) # )
tensors = dataset.fetch_from_idx_list([i for i in range(70)]) # tensors = dataset.fetch_from_idx_list([i for i in range(70)])
assert isinstance(tensors, dict) # assert isinstance(tensors, dict)
assert list(tensors.keys()) == ["data_1", "data_2"] # assert list(tensors.keys()) == ["data_1", "data_2"]
assert sorted(list(tensors["data_1"].keys())) == ["input", "target"] # assert sorted(list(tensors["data_1"].keys())) == ["input", "target"]
assert isinstance(tensors["data_1"]["input"], torch.Tensor) # assert isinstance(tensors["data_1"]["input"], torch.Tensor)
assert tensors["data_1"]["input"].shape == torch.Size((70, 10)) # assert tensors["data_1"]["input"].shape == torch.Size((70, 10))
assert isinstance(tensors["data_1"]["target"], torch.Tensor) # assert isinstance(tensors["data_1"]["target"], torch.Tensor)
assert tensors["data_1"]["target"].shape == torch.Size((70, 2)) # assert tensors["data_1"]["target"].shape == torch.Size((70, 2))
assert sorted(list(tensors["data_2"].keys())) == ["input", "target"] # assert sorted(list(tensors["data_2"].keys())) == ["input", "target"]
assert isinstance(tensors["data_2"]["input"], torch.Tensor) # assert isinstance(tensors["data_2"]["input"], torch.Tensor)
assert tensors["data_2"]["input"].shape == torch.Size((50, 10)) # assert tensors["data_2"]["input"].shape == torch.Size((50, 10))
assert isinstance(tensors["data_2"]["target"], torch.Tensor) # assert isinstance(tensors["data_2"]["target"], torch.Tensor)
assert tensors["data_2"]["target"].shape == torch.Size((50, 2)) # assert tensors["data_2"]["target"].shape == torch.Size((50, 2))