Fix bug and add additional tests for Dataset and DataModule (#517)
This commit is contained in:
committed by
FilippoOlivo
parent
79a7199985
commit
80c257da4d
@@ -217,12 +217,11 @@ class PinaSampler:
|
|||||||
parameter and the environment in which the code is running.
|
parameter and the environment in which the code is running.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, dataset, shuffle):
|
def __new__(cls, dataset):
|
||||||
"""
|
"""
|
||||||
Instantiate and initialize the sampler.
|
Instantiate and initialize the sampler.
|
||||||
|
|
||||||
:param PinaDataset dataset: The dataset from which to sample.
|
:param PinaDataset dataset: The dataset from which to sample.
|
||||||
:param bool shuffle: Whether to shuffle the dataset.
|
|
||||||
:return: The sampler instance.
|
:return: The sampler instance.
|
||||||
:rtype: :class:`torch.utils.data.Sampler`
|
:rtype: :class:`torch.utils.data.Sampler`
|
||||||
"""
|
"""
|
||||||
@@ -231,12 +230,9 @@ class PinaSampler:
|
|||||||
torch.distributed.is_available()
|
torch.distributed.is_available()
|
||||||
and torch.distributed.is_initialized()
|
and torch.distributed.is_initialized()
|
||||||
):
|
):
|
||||||
sampler = DistributedSampler(dataset, shuffle=shuffle)
|
sampler = DistributedSampler(dataset)
|
||||||
else:
|
else:
|
||||||
if shuffle:
|
sampler = SequentialSampler(dataset)
|
||||||
sampler = RandomSampler(dataset)
|
|
||||||
else:
|
|
||||||
sampler = SequentialSampler(dataset)
|
|
||||||
return sampler
|
return sampler
|
||||||
|
|
||||||
|
|
||||||
@@ -496,8 +492,6 @@ class PinaDataModule(LightningDataModule):
|
|||||||
:return: The dataloader for the given split.
|
:return: The dataloader for the given split.
|
||||||
:rtype: torch.utils.data.DataLoader
|
:rtype: torch.utils.data.DataLoader
|
||||||
"""
|
"""
|
||||||
|
|
||||||
shuffle = self.shuffle if split == "train" else False
|
|
||||||
# Suppress the warning about num_workers.
|
# Suppress the warning about num_workers.
|
||||||
# In many cases, especially for PINNs,
|
# In many cases, especially for PINNs,
|
||||||
# serial data loading can outperform parallel data loading.
|
# serial data loading can outperform parallel data loading.
|
||||||
@@ -511,7 +505,7 @@ class PinaDataModule(LightningDataModule):
|
|||||||
)
|
)
|
||||||
# Use custom batching (good if batch size is large)
|
# Use custom batching (good if batch size is large)
|
||||||
if self.batch_size is not None:
|
if self.batch_size is not None:
|
||||||
sampler = PinaSampler(dataset, shuffle)
|
sampler = PinaSampler(dataset)
|
||||||
if self.automatic_batching:
|
if self.automatic_batching:
|
||||||
collate = Collator(
|
collate = Collator(
|
||||||
self.find_max_conditions_lengths(split),
|
self.find_max_conditions_lengths(split),
|
||||||
|
|||||||
@@ -167,9 +167,15 @@ class PinaDataset(Dataset, ABC):
|
|||||||
:return: A dictionary containing all the data in the dataset.
|
:return: A dictionary containing all the data in the dataset.
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
to_return_dict = {}
|
||||||
index = list(range(len(self)))
|
for condition, data in self.conditions_dict.items():
|
||||||
return self.fetch_from_idx_list(index)
|
len_condition = len(
|
||||||
|
data["input"]
|
||||||
|
) # Length of the current condition
|
||||||
|
to_return_dict[condition] = self._retrive_data(
|
||||||
|
data, list(range(len_condition))
|
||||||
|
) # Retrieve the data from the current condition
|
||||||
|
return to_return_dict
|
||||||
|
|
||||||
def fetch_from_idx_list(self, idx):
|
def fetch_from_idx_list(self, idx):
|
||||||
"""
|
"""
|
||||||
@@ -306,3 +312,13 @@ class PinaGraphDataset(PinaDataset):
|
|||||||
)
|
)
|
||||||
for k, v in data.items()
|
for k, v in data.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input(self):
|
||||||
|
"""
|
||||||
|
Return the input data for the dataset.
|
||||||
|
|
||||||
|
:return: Dictionary containing the input points.
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
return {k: v["input"] for k, v in self.conditions_dict.items()}
|
||||||
|
|||||||
@@ -238,3 +238,94 @@ def test_dataloader_labels(input_, output_, automatic_batching):
|
|||||||
assert data["data"]["input"].labels == ["u", "v", "w"]
|
assert data["data"]["input"].labels == ["u", "v", "w"]
|
||||||
assert isinstance(data["data"]["target"], torch.Tensor)
|
assert isinstance(data["data"]["target"], torch.Tensor)
|
||||||
assert data["data"]["target"].labels == ["u", "v", "w"]
|
assert data["data"]["target"].labels == ["u", "v", "w"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_all_data():
|
||||||
|
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
|
||||||
|
target = input
|
||||||
|
|
||||||
|
problem = SupervisedProblem(input, target)
|
||||||
|
datamodule = PinaDataModule(
|
||||||
|
problem,
|
||||||
|
train_size=0.7,
|
||||||
|
test_size=0.2,
|
||||||
|
val_size=0.1,
|
||||||
|
batch_size=64,
|
||||||
|
shuffle=False,
|
||||||
|
repeat=False,
|
||||||
|
automatic_batching=None,
|
||||||
|
num_workers=0,
|
||||||
|
pin_memory=False,
|
||||||
|
)
|
||||||
|
datamodule.setup("fit")
|
||||||
|
datamodule.setup("test")
|
||||||
|
assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700
|
||||||
|
assert torch.isclose(
|
||||||
|
datamodule.train_dataset.get_all_data()["data"]["input"], input[:700]
|
||||||
|
).all()
|
||||||
|
assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100
|
||||||
|
assert torch.isclose(
|
||||||
|
datamodule.val_dataset.get_all_data()["data"]["input"], input[900:]
|
||||||
|
).all()
|
||||||
|
assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200
|
||||||
|
assert torch.isclose(
|
||||||
|
datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900]
|
||||||
|
).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_input_propery_tensor():
|
||||||
|
input = torch.stack([torch.zeros((1,)) + i for i in range(1000)])
|
||||||
|
target = input
|
||||||
|
|
||||||
|
problem = SupervisedProblem(input, target)
|
||||||
|
datamodule = PinaDataModule(
|
||||||
|
problem,
|
||||||
|
train_size=0.7,
|
||||||
|
test_size=0.2,
|
||||||
|
val_size=0.1,
|
||||||
|
batch_size=64,
|
||||||
|
shuffle=False,
|
||||||
|
repeat=False,
|
||||||
|
automatic_batching=None,
|
||||||
|
num_workers=0,
|
||||||
|
pin_memory=False,
|
||||||
|
)
|
||||||
|
datamodule.setup("fit")
|
||||||
|
datamodule.setup("test")
|
||||||
|
input_ = datamodule.input
|
||||||
|
assert isinstance(input_, dict)
|
||||||
|
assert isinstance(input_["train"], dict)
|
||||||
|
assert isinstance(input_["val"], dict)
|
||||||
|
assert isinstance(input_["test"], dict)
|
||||||
|
assert torch.isclose(input_["train"]["data"], input[:700]).all()
|
||||||
|
assert torch.isclose(input_["val"]["data"], input[900:]).all()
|
||||||
|
assert torch.isclose(input_["test"]["data"], input[700:900]).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_input_propery_graph():
|
||||||
|
problem = SupervisedProblem(input_graph, output_graph)
|
||||||
|
datamodule = PinaDataModule(
|
||||||
|
problem,
|
||||||
|
train_size=0.7,
|
||||||
|
test_size=0.2,
|
||||||
|
val_size=0.1,
|
||||||
|
batch_size=64,
|
||||||
|
shuffle=False,
|
||||||
|
repeat=False,
|
||||||
|
automatic_batching=None,
|
||||||
|
num_workers=0,
|
||||||
|
pin_memory=False,
|
||||||
|
)
|
||||||
|
datamodule.setup("fit")
|
||||||
|
datamodule.setup("test")
|
||||||
|
input_ = datamodule.input
|
||||||
|
assert isinstance(input_, dict)
|
||||||
|
assert isinstance(input_["train"], dict)
|
||||||
|
assert isinstance(input_["val"], dict)
|
||||||
|
assert isinstance(input_["test"], dict)
|
||||||
|
assert isinstance(input_["train"]["data"], list)
|
||||||
|
assert isinstance(input_["val"]["data"], list)
|
||||||
|
assert isinstance(input_["test"]["data"], list)
|
||||||
|
assert len(input_["train"]["data"]) == 70
|
||||||
|
assert len(input_["val"]["data"]) == 10
|
||||||
|
assert len(input_["test"]["data"]) == 20
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ conditions_dict_single = {
|
|||||||
max_conditions_lengths_single = {"data": 100}
|
max_conditions_lengths_single = {"data": 100}
|
||||||
|
|
||||||
# Problem with multiple conditions
|
# Problem with multiple conditions
|
||||||
conditions_dict_single_multi = {
|
conditions_dict_multi = {
|
||||||
"data_1": {
|
"data_1": {
|
||||||
"input": input_,
|
"input": input_,
|
||||||
"target": output_,
|
"target": output_,
|
||||||
@@ -49,7 +49,7 @@ max_conditions_lengths_multi = {"data_1": 100, "data_2": 50}
|
|||||||
"conditions_dict, max_conditions_lengths",
|
"conditions_dict, max_conditions_lengths",
|
||||||
[
|
[
|
||||||
(conditions_dict_single, max_conditions_lengths_single),
|
(conditions_dict_single, max_conditions_lengths_single),
|
||||||
(conditions_dict_single_multi, max_conditions_lengths_multi),
|
(conditions_dict_multi, max_conditions_lengths_multi),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_constructor(conditions_dict, max_conditions_lengths):
|
def test_constructor(conditions_dict, max_conditions_lengths):
|
||||||
@@ -66,7 +66,7 @@ def test_constructor(conditions_dict, max_conditions_lengths):
|
|||||||
"conditions_dict, max_conditions_lengths",
|
"conditions_dict, max_conditions_lengths",
|
||||||
[
|
[
|
||||||
(conditions_dict_single, max_conditions_lengths_single),
|
(conditions_dict_single, max_conditions_lengths_single),
|
||||||
(conditions_dict_single_multi, max_conditions_lengths_multi),
|
(conditions_dict_multi, max_conditions_lengths_multi),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_getitem(conditions_dict, max_conditions_lengths):
|
def test_getitem(conditions_dict, max_conditions_lengths):
|
||||||
@@ -110,3 +110,29 @@ def test_getitem(conditions_dict, max_conditions_lengths):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])
|
assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()])
|
||||||
|
|
||||||
|
|
||||||
|
def test_input_single_condition():
|
||||||
|
dataset = PinaDatasetFactory(
|
||||||
|
conditions_dict_single,
|
||||||
|
max_conditions_lengths=max_conditions_lengths_single,
|
||||||
|
automatic_batching=True,
|
||||||
|
)
|
||||||
|
input_ = dataset.input
|
||||||
|
assert isinstance(input_, dict)
|
||||||
|
assert isinstance(input_["data"], list)
|
||||||
|
assert all([isinstance(d, Data) for d in input_["data"]])
|
||||||
|
|
||||||
|
|
||||||
|
def test_input_multi_condition():
|
||||||
|
dataset = PinaDatasetFactory(
|
||||||
|
conditions_dict_multi,
|
||||||
|
max_conditions_lengths=max_conditions_lengths_multi,
|
||||||
|
automatic_batching=True,
|
||||||
|
)
|
||||||
|
input_ = dataset.input
|
||||||
|
assert isinstance(input_, dict)
|
||||||
|
assert isinstance(input_["data_1"], list)
|
||||||
|
assert all([isinstance(d, Data) for d in input_["data_1"]])
|
||||||
|
assert isinstance(input_["data_2"], list)
|
||||||
|
assert all([isinstance(d, Data) for d in input_["data_2"]])
|
||||||
|
|||||||
Reference in New Issue
Block a user