Fix bug in Collector with Graph data (#456)

* Fix bug in Collector with Graph data
* Add comments in DataModule class and bug fix in collate
committed by Nicola Demo
parent dfd6d7b467
commit 9c9d4fe7e4
@@ -1,3 +1,7 @@
+"""
+# TODO
+"""
+from .graph import Graph
 from .utils import check_consistency


@@ -52,6 +56,8 @@ class Collector:
             # get data
             keys = condition.__slots__
             values = [getattr(condition, name) for name in keys]
+            values = [value.data if isinstance(
+                value, Graph) else value for value in values]
             self.data_collections[condition_name] = dict(zip(keys, values))
             # condition now is ready
             self._is_conditions_ready[condition_name] = True
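The core of the fix: when a condition stores a Graph object, the Collector now unwraps it to its underlying list of pyG Data objects before building the data collection. A minimal sketch of the behaviour (not part of the commit), assuming Graph keeps its torch_geometric Data objects in a `.data` list:

class Graph:                         # stand-in for pina's Graph
    def __init__(self, data):
        self.data = data             # list of pyG Data objects

g = Graph(data=["d0", "d1"])         # placeholders instead of real Data objects
values = [g, 42]
values = [v.data if isinstance(v, Graph) else v for v in values]
assert values == [["d0", "d1"], 42]  # Graph unwrapped, other values untouched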
@@ -2,11 +2,11 @@ import logging
 import warnings
 from lightning.pytorch import LightningDataModule
 import torch
-from ..label_tensor import LabelTensor
-from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
-    RandomSampler
+from torch_geometric.data import Data
+from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
-from .dataset import PinaDatasetFactory
+from ..label_tensor import LabelTensor
+from .dataset import PinaDatasetFactory, PinaTensorDataset
 from ..collector import Collector


@@ -61,6 +61,10 @@ class Collator:
             max_conditions_lengths is None else (
             self._collate_standard_dataloader)
         self.dataset = dataset
+        if isinstance(self.dataset, PinaTensorDataset):
+            self._collate = self._collate_tensor_dataset
+        else:
+            self._collate = self._collate_graph_dataset

     def _collate_custom_dataloader(self, batch):
         return self.dataset.fetch_from_idx_list(batch)
@@ -73,7 +77,6 @@ class Collator:
         if isinstance(batch, dict):
             return batch
         conditions_names = batch[0].keys()
-
         # Condition names
         for condition_name in conditions_names:
             single_cond_dict = {}
@@ -82,16 +85,28 @@ class Collator:
                 data_list = [batch[idx][condition_name][arg] for idx in range(
                     min(len(batch),
                         self.max_conditions_lengths[condition_name]))]
-                if isinstance(data_list[0], LabelTensor):
-                    single_cond_dict[arg] = LabelTensor.stack(data_list)
-                elif isinstance(data_list[0], torch.Tensor):
-                    single_cond_dict[arg] = torch.stack(data_list)
-                else:
-                    raise NotImplementedError(
-                        f"Data type {type(data_list[0])} not supported")
+                single_cond_dict[arg] = self._collate(data_list)
             batch_dict[condition_name] = single_cond_dict
         return batch_dict

+    @staticmethod
+    def _collate_tensor_dataset(data_list):
+        if isinstance(data_list[0], LabelTensor):
+            return LabelTensor.stack(data_list)
+        if isinstance(data_list[0], torch.Tensor):
+            return torch.stack(data_list)
+        raise RuntimeError("Data must be Tensors or LabelTensor ")
+
+    def _collate_graph_dataset(self, data_list):
+        if isinstance(data_list[0], LabelTensor):
+            return LabelTensor.cat(data_list)
+        if isinstance(data_list[0], torch.Tensor):
+            return torch.cat(data_list)
+        if isinstance(data_list[0], Data):
+            return self.dataset.create_graph_batch(data_list)
+        raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
+
     def __call__(self, batch):
         return self.callable_function(batch)

@@ -125,7 +140,7 @@ class PinaDataModule(LightningDataModule):
                  batch_size=None,
                  shuffle=True,
                  repeat=False,
-                 automatic_batching=False,
+                 automatic_batching=None,
                  num_workers=0,
                  pin_memory=False,
                  ):
@@ -158,15 +173,35 @@ class PinaDataModule(LightningDataModule):
         logging.debug('Start initialization of Pina DataModule')
         logging.info('Start initialization of Pina DataModule')
         super().__init__()
-        self.automatic_batching = automatic_batching
+        # Store fixed attributes
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.repeat = repeat
+        self.automatic_batching = automatic_batching
+        if batch_size is None and num_workers != 0:
+            warnings.warn(
+                "Setting num_workers when batch_size is None has no effect on "
+                "the DataLoading process.")
+            self.num_workers = 0
+        else:
+            self.num_workers = num_workers
+        if batch_size is None and pin_memory:
+            warnings.warn("Setting pin_memory to True has no effect when "
+                          "batch_size is None.")
+            self.pin_memory = False
+        else:
+            self.pin_memory = pin_memory
+
+        # Collect data
+        collector = Collector(problem)
+        collector.store_fixed_data()
+        collector.store_sample_domains()
+
         # Check if the splits are correct
         self._check_slit_sizes(train_size, test_size, val_size, predict_size)

-        # Begin Data splitting
+        # Split input data into subsets
         splits_dict = {}
         if train_size > 0:
             splits_dict['train'] = train_size
@@ -188,19 +223,6 @@ class PinaDataModule(LightningDataModule):
             self.predict_dataset = None
         else:
             self.predict_dataloader = super().predict_dataloader
-
-        collector = Collector(problem)
-        collector.store_fixed_data()
-        collector.store_sample_domains()
-        if batch_size is None and num_workers != 0:
-            warnings.warn(
-                "Setting num_workers when batch_size is None has no effect on "
-                "the DataLoading process.")
-        if batch_size is None and pin_memory:
-            warnings.warn("Setting pin_memory to True has no effect when "
-                          "batch_size is None.")
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
         self.collector_splits = self._create_splits(collector, splits_dict)
         self.transfer_batch_to_device = self._transfer_batch_to_device

@@ -316,10 +338,10 @@ class PinaDataModule(LightningDataModule):
         if self.batch_size is not None:
             sampler = PinaSampler(dataset, shuffle)
             if self.automatic_batching:
-                collate = Collator(self.find_max_conditions_lengths(split))
+                collate = Collator(self.find_max_conditions_lengths(split),
+                                   dataset=dataset)
             else:
-                collate = Collator(None, dataset)
+                collate = Collator(None, dataset=dataset)
             return DataLoader(dataset, self.batch_size,
                               collate_fn=collate, sampler=sampler,
                               num_workers=self.num_workers)
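The Collator now picks its collate function once, from the dataset type, instead of probing each sample's type inside the loop: tensor datasets stack samples along a new batch dimension, while graph datasets concatenate along the node dimension, matching pyG's batching convention. A quick illustration of the two paths (not part of the commit):

import torch

samples = [torch.rand(50, 3) for _ in range(4)]
print(torch.stack(samples).shape)  # torch.Size([4, 50, 3]) -- tensor-dataset path
print(torch.cat(samples).shape)    # torch.Size([200, 3])   -- graph-dataset path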
@@ -1,10 +1,12 @@
 """
 This module provide basic data management functionalities
 """
+import functools
 import torch
 from torch.utils.data import Dataset
 from abc import abstractmethod
-from torch_geometric.data import Batch
+from torch_geometric.data import Batch, Data
+from pina import LabelTensor


 class PinaDatasetFactory:
@@ -62,7 +64,7 @@ class PinaTensorDataset(PinaDataset):
         if automatic_batching:
             self._getitem_func = self._getitem_int
         else:
-            self._getitem_func = self._getitem_list
+            self._getitem_func = self._getitem_dummy

     def _getitem_int(self, idx):
         return {
@@ -82,7 +84,7 @@ class PinaTensorDataset(PinaDataset):
         return to_return_dict

     @staticmethod
-    def _getitem_list(idx):
+    def _getitem_dummy(idx):
         return idx

     def get_all_data(self):
@@ -102,15 +104,56 @@ class PinaTensorDataset(PinaDataset):
         }


+class PinaBatch(Batch):
+    """
+    Add extract function to torch_geometric Batch object
+    """
+
+    def __init__(self):
+        super().__init__(self)
+
+    def extract(self, labels):
+        """
+        Perform extraction of labels on node features (x)
+
+        :param labels: Labels to extract
+        :type labels: list[str] | tuple[str] | str
+        :return: Batch object with extraction performed on x
+        :rtype: PinaBatch
+        """
+        self.x = self.x.extract(labels)
+        return self
+
+
 class PinaGraphDataset(PinaDataset):

     def __init__(self, conditions_dict, max_conditions_lengths,
                  automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
+        self.in_labels = {}
+        self.out_labels = None
         if automatic_batching:
             self._getitem_func = self._getitem_int
         else:
-            self._getitem_func = self._getitem_list
+            self._getitem_func = self._getitem_dummy
+
+        ex_data = conditions_dict[list(conditions_dict.keys())[
+            0]]['input_points'][0]
+        for name, attr in ex_data.items():
+            if isinstance(attr, LabelTensor):
+                self.in_labels[name] = attr.stored_labels
+        ex_data = conditions_dict[list(conditions_dict.keys())[
+            0]]['output_points'][0]
+        if isinstance(ex_data, LabelTensor):
+            self.out_labels = ex_data.labels
+
+        self._create_graph_batch_from_list = self._labelise_batch(
+            self._base_create_graph_batch_from_list) if self.in_labels \
+            else self._base_create_graph_batch_from_list
+
+        self._create_output_batch = self._labelise_tensor(
+            self._base_create_output_batch) if self.out_labels is not None \
+            else self._base_create_output_batch

     def fetch_from_idx_list(self, idx):
         to_return_dict = {}
@@ -119,17 +162,24 @@ class PinaGraphDataset(PinaDataset):
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
                 cond_idx = [idx % condition_len for idx in cond_idx]
-            to_return_dict[condition] = {k: Batch.from_data_list([
-                v[i] for i in cond_idx])
-                if isinstance(v, list)
-                else v[
-                    cond_idx].reshape(
-                    -1, *v[cond_idx].shape[2:])
-                for k, v in data.items()
-            }
+            to_return_dict[condition] = {
+                k: self._create_graph_batch_from_list([v[i] for i in idx])
+                if isinstance(v, list)
+                else self._create_output_batch(v[idx])
+                for k, v in data.items()
+            }
         return to_return_dict

-    def _getitem_list(self, idx):
+    def _base_create_graph_batch_from_list(self, data):
+        batch = PinaBatch.from_data_list(data)
+        return batch
+
+    def _base_create_output_batch(self, data):
+        out = data.reshape(-1, *data.shape[2:])
+        return out
+
+    def _getitem_dummy(self, idx):
         return idx

     def _getitem_int(self, idx):
@@ -144,3 +194,31 @@ class PinaGraphDataset(PinaDataset):

     def __getitem__(self, idx):
         return self._getitem_func(idx)
+
+    def _labelise_batch(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            batch = func(*args, **kwargs)
+            for k, v in self.in_labels.items():
+                tmp = batch[k]
+                tmp.labels = v
+                batch[k] = tmp
+            return batch
+        return wrapper
+
+    def _labelise_tensor(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            out = func(*args, **kwargs)
+            if isinstance(out, LabelTensor):
+                out.labels = self.out_labels
+            return out
+        return wrapper
+
+    def create_graph_batch(self, data):
+        """
+        # TODO
+        """
+        if isinstance(data[0], Data):
+            return self._create_graph_batch_from_list(data)
+        return self._create_output_batch(data)
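_labelise_batch and _labelise_tensor exist because batching through pyG and plain reshape both return ordinary tensors, dropping LabelTensor label metadata; the wrappers re-attach the labels recorded at dataset construction. A self-contained sketch of the pattern (not part of the commit; Dummy stands in for a batched LabelTensor):

import functools

def labelise(build_fn, labels):
    @functools.wraps(build_fn)
    def wrapper(*args, **kwargs):
        out = build_fn(*args, **kwargs)
        out.labels = labels         # restore metadata lost during batching
        return out
    return wrapper

class Dummy:                        # stand-in for a batched LabelTensor
    labels = None

build = labelise(lambda: Dummy(), ['u', 'v', 'w'])
assert build().labels == ['u', 'v', 'w']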
@@ -108,16 +108,14 @@ class Graph:
                           x)

         # Perform the graph construction
-        self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
+        self._build_graph_list(
+            x, pos, edge_index, edge_attr, additional_params)

     def _build_graph_list(self, x, pos, edge_index, edge_attr,
                           additional_params):
         for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
-            if isinstance(x_, LabelTensor):
-                x_ = x_.tensor
             add_params_local = {k: v[i] for k, v in additional_params.items()}
             if edge_attr is not None:
-
                 self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
                                       edge_attr=edge_attr[i],
                                       **add_params_local))
@@ -127,7 +125,8 @@ class Graph:

     @staticmethod
     def _build_edge_attr(x, pos, edge_index):
-        distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
+        distance = torch.abs(pos[edge_index[0]] -
+                             pos[edge_index[1]]).as_subclass(torch.Tensor)
         return distance

     @staticmethod
@@ -165,7 +164,8 @@ class Graph:
         # If edge_index is a 3D tensor, we split it into a list of 2D tensors
         if edge_index is not None:
             if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
-                edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
+                edge_index = [edge_index[i]
+                              for i in range(edge_index.shape[0])]
             elif not (isinstance(edge_index, list) and all(
                     t.ndim == 2 for t in edge_index)) and not (
                     isinstance(edge_index,
@@ -219,7 +219,7 @@ class Graph:
         if isinstance(edge_attr, list):
             if len(edge_attr) != data_len:
                 raise TypeError("edge_attr must have the same length as x "
                                 "and pos.")
             return [edge_attr] * data_len

         if build_edge_attr:
@@ -258,6 +258,8 @@ class RadiusGraph(Graph):
         """
         dist = torch.cdist(points, points, p=2)
         edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
+        if isinstance(edge_index, LabelTensor):
+            edge_index = edge_index.tensor
         return edge_index


@@ -293,4 +295,6 @@ class KNNGraph(Graph):
         row = torch.arange(points.size(0)).repeat_interleave(k)
         col = knn_indices.flatten()
         edge_index = torch.stack([row, col], dim=0)
+        if isinstance(edge_index, LabelTensor):
+            edge_index = edge_index.tensor
         return edge_index
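For reference, the radius-graph edge rule used by RadiusGraph above, runnable on its own (not part of the commit; note self-loops are included, since dist(i, i) = 0 <= r):

import torch

points = torch.rand(5, 2)
dist = torch.cdist(points, points, p=2)                      # (5, 5) pairwise distances
edge_index = torch.nonzero(dist <= 0.5, as_tuple=False).t()  # (2, E) edge list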
@@ -105,9 +105,9 @@ class Trainer(lightning.pytorch.Trainer):
         # checking compilation and automatic batching
         if compile is None or sys.platform == "win32":
             compile = False
-        if automatic_batching is None:
-            automatic_batching = False

+        self.automatic_batching = automatic_batching if automatic_batching \
+            is not None else False
         # set attributes
         self.compile = compile
         self.solver = solver
@@ -115,7 +115,7 @@ class Trainer(lightning.pytorch.Trainer):
         self._move_to_device()
         self.data_module = None
         self._create_datamodule(train_size, test_size, val_size, predict_size,
                                 batch_size, automatic_batching, pin_memory,
                                 num_workers)

         # logging
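Moving the automatic_batching default from False to None (here and in the PinaDataModule constructor above) lets the code distinguish "not set" from an explicit False, with the ternary resolving the default in one place. A sketch of the pattern (not part of the commit):

def resolve(automatic_batching=None):
    return automatic_batching if automatic_batching is not None else False

assert resolve() is False
assert resolve(True) is True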
@@ -13,10 +13,10 @@ from torch.utils.data import DataLoader
 input_tensor = torch.rand((100, 10))
 output_tensor = torch.rand((100, 2))

-x = torch.rand((100, 50 , 10))
-pos = torch.rand((100, 50 , 2))
+x = torch.rand((100, 50, 10))
+pos = torch.rand((100, 50, 2))
 input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True)
-output_graph = torch.rand((100, 50 , 10))
+output_graph = torch.rand((100, 50, 10))


 @pytest.mark.parametrize(
@@ -30,6 +30,7 @@ def test_constructor(input_, output_):
     problem = SupervisedProblem(input_=input_, output_=output_)
     PinaDataModule(problem)

+
 @pytest.mark.parametrize(
     "input_, output_",
     [
@@ -46,14 +47,15 @@ def test_constructor(input_, output_):
 )
 def test_setup_train(input_, output_, train_size, val_size, test_size):
     problem = SupervisedProblem(input_=input_, output_=output_)
-    dm = PinaDataModule(problem, train_size=train_size, val_size=val_size, test_size=test_size)
+    dm = PinaDataModule(problem, train_size=train_size,
+                        val_size=val_size, test_size=test_size)
     dm.setup()
     assert hasattr(dm, "train_dataset")
     if isinstance(input_, torch.Tensor):
         assert isinstance(dm.train_dataset, PinaTensorDataset)
     else:
         assert isinstance(dm.train_dataset, PinaGraphDataset)
-    #assert len(dm.train_dataset) == int(len(input_) * train_size)
+    # assert len(dm.train_dataset) == int(len(input_) * train_size)
     if test_size > 0:
         assert hasattr(dm, "test_dataset")
         assert dm.test_dataset is None
@@ -64,7 +66,8 @@ def test_setup_train(input_, output_, train_size, val_size, test_size):
         assert isinstance(dm.val_dataset, PinaTensorDataset)
     else:
         assert isinstance(dm.val_dataset, PinaGraphDataset)
-    #assert len(dm.val_dataset) == int(len(input_) * val_size)
+    # assert len(dm.val_dataset) == int(len(input_) * val_size)

+
 @pytest.mark.parametrize(
     "input_, output_",
@@ -82,7 +85,8 @@ def test_setup_train(input_, output_, train_size, val_size, test_size):
 )
 def test_setup_test(input_, output_, train_size, val_size, test_size):
     problem = SupervisedProblem(input_=input_, output_=output_)
-    dm = PinaDataModule(problem, train_size=train_size, val_size=val_size, test_size=test_size)
+    dm = PinaDataModule(problem, train_size=train_size,
+                        val_size=val_size, test_size=test_size)
     dm.setup(stage='test')
     if train_size > 0:
         assert hasattr(dm, "train_dataset")
@@ -94,13 +98,14 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
         assert dm.val_dataset is None
     else:
         assert not hasattr(dm, "val_dataset")

     assert hasattr(dm, "test_dataset")
     if isinstance(input_, torch.Tensor):
         assert isinstance(dm.test_dataset, PinaTensorDataset)
     else:
         assert isinstance(dm.test_dataset, PinaGraphDataset)
-    #assert len(dm.test_dataset) == int(len(input_) * test_size)
+    # assert len(dm.test_dataset) == int(len(input_) * test_size)

+
 @pytest.mark.parametrize(
     "input_, output_",
@@ -112,7 +117,8 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
 def test_dummy_dataloader(input_, output_):
     problem = SupervisedProblem(input_=input_, output_=output_)
     solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
-    trainer = Trainer(solver, batch_size=None, train_size=.7, val_size=.3, test_size=0.)
+    trainer = Trainer(solver, batch_size=None, train_size=.7,
+                      val_size=.3, test_size=0.)
     dm = trainer.data_module
     dm.setup()
     dm.trainer = trainer
@@ -140,6 +146,7 @@ def test_dummy_dataloader(input_, output_):
     assert isinstance(data[0][1]['input_points'], torch.Tensor)
     assert isinstance(data[0][1]['output_points'], torch.Tensor)

+
 @pytest.mark.parametrize(
     "input_, output_",
     [
@@ -147,10 +154,17 @@ def test_dummy_dataloader(input_, output_):
     (input_graph, output_graph)
     ]
 )
-def test_dataloader(input_, output_):
+@pytest.mark.parametrize(
+    "automatic_batching",
+    [
+        True, False
+    ]
+)
+def test_dataloader(input_, output_, automatic_batching):
     problem = SupervisedProblem(input_=input_, output_=output_)
     solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
-    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3, test_size=0.)
+    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
+                      test_size=0., automatic_batching=automatic_batching)
     dm = trainer.data_module
     dm.setup()
     dm.trainer = trainer
@@ -176,3 +190,67 @@ def test_dataloader(input_, output_):
     assert isinstance(data['data']['input_points'], torch.Tensor)
     assert isinstance(data['data']['output_points'], torch.Tensor)

+
+from pina import LabelTensor
+
+input_tensor = LabelTensor(torch.rand((100, 3)), ['u', 'v', 'w'])
+output_tensor = LabelTensor(torch.rand((100, 3)), ['u', 'v', 'w'])
+
+x = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w'])
+pos = LabelTensor(torch.rand((100, 50, 2)), ['x', 'y'])
+input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True)
+output_graph = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w'])
+
+
+@pytest.mark.parametrize(
+    "input_, output_",
+    [
+        (input_tensor, output_tensor),
+        (input_graph, output_graph)
+    ]
+)
+@pytest.mark.parametrize(
+    "automatic_batching",
+    [
+        True, False
+    ]
+)
+def test_dataloader_labels(input_, output_, automatic_batching):
+    problem = SupervisedProblem(input_=input_, output_=output_)
+    solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
+    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
+                      test_size=0., automatic_batching=automatic_batching)
+    dm = trainer.data_module
+    dm.setup()
+    dm.trainer = trainer
+    dataloader = dm.train_dataloader()
+    assert isinstance(dataloader, DataLoader)
+    assert len(dataloader) == 7
+    data = next(iter(dataloader))
+    assert isinstance(data, dict)
+    if isinstance(input_, RadiusGraph):
+        assert isinstance(data['data']['input_points'], Batch)
+        assert isinstance(data['data']['input_points'].x, LabelTensor)
+        assert data['data']['input_points'].x.labels == ['u', 'v', 'w']
+        assert data['data']['input_points'].pos.labels == ['x', 'y']
+    else:
+        assert isinstance(data['data']['input_points'], LabelTensor)
+        assert data['data']['input_points'].labels == ['u', 'v', 'w']
+    assert isinstance(data['data']['output_points'], LabelTensor)
+    assert data['data']['output_points'].labels == ['u', 'v', 'w']
+
+    dataloader = dm.val_dataloader()
+    assert isinstance(dataloader, DataLoader)
+    assert len(dataloader) == 3
+    data = next(iter(dataloader))
+    assert isinstance(data, dict)
+    if isinstance(input_, RadiusGraph):
+        assert isinstance(data['data']['input_points'], Batch)
+        assert isinstance(data['data']['input_points'].x, LabelTensor)
+        assert data['data']['input_points'].x.labels == ['u', 'v', 'w']
+        assert data['data']['input_points'].pos.labels == ['x', 'y']
+    else:
+        assert isinstance(data['data']['input_points'], torch.Tensor)
+        assert isinstance(data['data']['input_points'], LabelTensor)
+        assert data['data']['input_points'].labels == ['u', 'v', 'w']
+    assert isinstance(data['data']['output_points'], torch.Tensor)
+    assert data['data']['output_points'].labels == ['u', 'v', 'w']
+
+
+test_dataloader_labels(input_graph, output_graph, True)