Fix bug in Collector with Graph data (#456)

* Fix bug in Collector with Graph data
* Add comments in DataModule class and bug fix in collate
Author: Filippo Olivo
Date: 2025-02-20 13:49:01 +01:00
Committed by: Nicola Demo
Parent: dfd6d7b467
Commit: 9c9d4fe7e4
6 changed files with 254 additions and 66 deletions

File 1 of 6 — collector module (Collector)

@@ -1,3 +1,7 @@
+"""
+# TODO
+"""
+from .graph import Graph
 from .utils import check_consistency

@@ -52,6 +56,8 @@ class Collector:
            # get data
            keys = condition.__slots__
            values = [getattr(condition, name) for name in keys]
+            values = [value.data if isinstance(
+                value, Graph) else value for value in values]
            self.data_collections[condition_name] = dict(zip(keys, values))
            # condition now is ready
            self._is_conditions_ready[condition_name] = True
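
The two-line fix above is the core of the commit: condition values holding a Graph are unwrapped to the underlying list of torch_geometric Data objects before being stored, so downstream datasets only ever see plain data. A minimal, self-contained sketch of the same pattern; the Graph class below is a hypothetical stand-in that, like the real one, exposes its graphs through a .data attribute:

import torch
from torch_geometric.data import Data

class Graph:
    # Hypothetical stand-in for PINA's Graph: wraps a list of Data objects.
    def __init__(self, data):
        self.data = data

values = [
    Graph([Data(x=torch.rand(5, 3)) for _ in range(2)]),  # graph-valued field
    torch.rand(10, 3),                                    # tensor-valued field
]
# Same expression as in the patch: unwrap Graph, pass everything else through.
values = [value.data if isinstance(value, Graph) else value
          for value in values]
assert isinstance(values[0], list) and isinstance(values[0][0], Data)
assert isinstance(values[1], torch.Tensor)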

File 2 of 6 — data module (Collator, PinaDataModule)

@@ -2,11 +2,11 @@ import logging
 import warnings
 from lightning.pytorch import LightningDataModule
 import torch
-from ..label_tensor import LabelTensor
-from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
-    RandomSampler
+from torch_geometric.data import Data
+from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
-from .dataset import PinaDatasetFactory
+from ..label_tensor import LabelTensor
+from .dataset import PinaDatasetFactory, PinaTensorDataset
 from ..collector import Collector

@@ -61,6 +61,10 @@ class Collator:
             max_conditions_lengths is None else (
                 self._collate_standard_dataloader)
         self.dataset = dataset
+        if isinstance(self.dataset, PinaTensorDataset):
+            self._collate = self._collate_tensor_dataset
+        else:
+            self._collate = self._collate_graph_dataset

     def _collate_custom_dataloader(self, batch):
         return self.dataset.fetch_from_idx_list(batch)

@@ -73,7 +77,6 @@ class Collator:
         if isinstance(batch, dict):
             return batch
         conditions_names = batch[0].keys()
-        # Condition names
         for condition_name in conditions_names:
             single_cond_dict = {}

@@ -82,16 +85,28 @@ class Collator:
                 data_list = [batch[idx][condition_name][arg] for idx in range(
                     min(len(batch),
                         self.max_conditions_lengths[condition_name]))]
-                if isinstance(data_list[0], LabelTensor):
-                    single_cond_dict[arg] = LabelTensor.stack(data_list)
-                elif isinstance(data_list[0], torch.Tensor):
-                    single_cond_dict[arg] = torch.stack(data_list)
-                else:
-                    raise NotImplementedError(
-                        f"Data type {type(data_list[0])} not supported")
+                single_cond_dict[arg] = self._collate(data_list)
             batch_dict[condition_name] = single_cond_dict
         return batch_dict

+    @staticmethod
+    def _collate_tensor_dataset(data_list):
+        if isinstance(data_list[0], LabelTensor):
+            return LabelTensor.stack(data_list)
+        if isinstance(data_list[0], torch.Tensor):
+            return torch.stack(data_list)
+        raise RuntimeError("Data must be Tensors or LabelTensor ")
+
+    def _collate_graph_dataset(self, data_list):
+        if isinstance(data_list[0], LabelTensor):
+            return LabelTensor.cat(data_list)
+        if isinstance(data_list[0], torch.Tensor):
+            return torch.cat(data_list)
+        if isinstance(data_list[0], Data):
+            return self.dataset.create_graph_batch(data_list)
+        raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
+
     def __call__(self, batch):
         return self.callable_function(batch)
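
The dispatch introduced above is the heart of the collate fix: tensor datasets stack samples into a new batch dimension, while graph datasets concatenate along the node dimension (and delegate lists of Data objects to the dataset's own batcher). A small sketch of why the two paths must differ, using plain torch only:

import torch

# Four samples, each with 50 nodes and 3 features.
items = [torch.rand(50, 3) for _ in range(4)]

# Tensor-dataset path: stack adds a leading batch dimension.
stacked = torch.stack(items)
assert stacked.shape == (4, 50, 3)

# Graph-dataset path: cat flattens samples along the node dimension,
# matching how torch_geometric concatenates node features of many graphs.
concatenated = torch.cat(items)
assert concatenated.shape == (200, 3)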
@@ -125,7 +140,7 @@ class PinaDataModule(LightningDataModule):
                  batch_size=None,
                  shuffle=True,
                  repeat=False,
-                 automatic_batching=False,
+                 automatic_batching=None,
                  num_workers=0,
                  pin_memory=False,
                  ):
@@ -158,15 +173,35 @@ class PinaDataModule(LightningDataModule):
         logging.debug('Start initialization of Pina DataModule')
         logging.info('Start initialization of Pina DataModule')
         super().__init__()
-        self.automatic_batching = automatic_batching
+
+        # Store fixed attributes
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.repeat = repeat
+        self.automatic_batching = automatic_batching
+        if batch_size is None and num_workers != 0:
+            warnings.warn(
+                "Setting num_workers when batch_size is None has no effect on "
+                "the DataLoading process.")
+            self.num_workers = 0
+        else:
+            self.num_workers = num_workers
+        if batch_size is None and pin_memory:
+            warnings.warn("Setting pin_memory to True has no effect when "
+                          "batch_size is None.")
+            self.pin_memory = False
+        else:
+            self.pin_memory = pin_memory
+
+        # Collect data
+        collector = Collector(problem)
+        collector.store_fixed_data()
+        collector.store_sample_domains()
+
         # Check if the splits are correct
         self._check_slit_sizes(train_size, test_size, val_size, predict_size)
-        # Begin Data splitting
+
+        # Split input data into subsets
         splits_dict = {}
         if train_size > 0:
             splits_dict['train'] = train_size
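
Besides moving the collector and the warnings earlier in __init__, this hunk changes behaviour: previously the DataModule only warned and then kept the useless values, whereas now it also forces num_workers to 0 and pin_memory to False when batch_size is None. A condensed sketch of the sanitisation, assuming exactly the logic shown above:

import warnings

def sanitise(batch_size, num_workers, pin_memory):
    # With batch_size=None the whole dataset is fetched in one shot, so
    # worker processes and pinned memory buy nothing; reset them.
    if batch_size is None and num_workers != 0:
        warnings.warn("num_workers has no effect when batch_size is None.")
        num_workers = 0
    if batch_size is None and pin_memory:
        warnings.warn("pin_memory has no effect when batch_size is None.")
        pin_memory = False
    return num_workers, pin_memory

assert sanitise(None, 4, True) == (0, False)   # both reset, two warnings
assert sanitise(32, 4, True) == (4, True)      # untouched when batching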
@@ -188,19 +223,6 @@ class PinaDataModule(LightningDataModule):
             self.predict_dataset = None
         else:
             self.predict_dataloader = super().predict_dataloader
-        collector = Collector(problem)
-        collector.store_fixed_data()
-        collector.store_sample_domains()
-        if batch_size is None and num_workers != 0:
-            warnings.warn(
-                "Setting num_workers when batch_size is None has no effect on "
-                "the DataLoading process.")
-        if batch_size is None and pin_memory:
-            warnings.warn("Setting pin_memory to True has no effect when "
-                          "batch_size is None.")
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
         self.collector_splits = self._create_splits(collector, splits_dict)
         self.transfer_batch_to_device = self._transfer_batch_to_device
@@ -316,10 +338,10 @@ class PinaDataModule(LightningDataModule):
         if self.batch_size is not None:
             sampler = PinaSampler(dataset, shuffle)
             if self.automatic_batching:
-                collate = Collator(self.find_max_conditions_lengths(split))
+                collate = Collator(self.find_max_conditions_lengths(split),
+                                   dataset=dataset)
             else:
-                collate = Collator(None, dataset)
+                collate = Collator(None, dataset=dataset)
             return DataLoader(dataset, self.batch_size,
                               collate_fn=collate, sampler=sampler,
                               num_workers=self.num_workers)

File 3 of 6 — dataset module (PinaDatasetFactory, PinaTensorDataset, PinaGraphDataset)

@@ -1,10 +1,12 @@
 """
 This module provide basic data management functionalities
 """
+import functools
 import torch
 from torch.utils.data import Dataset
 from abc import abstractmethod
-from torch_geometric.data import Batch
+from torch_geometric.data import Batch, Data
+from pina import LabelTensor


 class PinaDatasetFactory:

@@ -62,7 +64,7 @@ class PinaTensorDataset(PinaDataset):
         if automatic_batching:
             self._getitem_func = self._getitem_int
         else:
-            self._getitem_func = self._getitem_list
+            self._getitem_func = self._getitem_dummy

     def _getitem_int(self, idx):
         return {

@@ -82,7 +84,7 @@ class PinaTensorDataset(PinaDataset):
         return to_return_dict

     @staticmethod
-    def _getitem_list(idx):
+    def _getitem_dummy(idx):
         return idx

     def get_all_data(self):
@@ -102,15 +104,56 @@ class PinaTensorDataset(PinaDataset):
         }


+class PinaBatch(Batch):
+    """
+    Add extract function to torch_geometric Batch object
+    """
+
+    def __init__(self):
+        super().__init__(self)
+
+    def extract(self, labels):
+        """
+        Perform extraction of labels on node features (x)
+
+        :param labels: Labels to extract
+        :type labels: list[str] | tuple[str] | str
+        :return: Batch object with extraction performed on x
+        :rtype: PinaBatch
+        """
+        self.x = self.x.extract(labels)
+        return self
+
+
 class PinaGraphDataset(PinaDataset):

     def __init__(self, conditions_dict, max_conditions_lengths,
                  automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
+        self.in_labels = {}
+        self.out_labels = None
         if automatic_batching:
             self._getitem_func = self._getitem_int
         else:
-            self._getitem_func = self._getitem_list
+            self._getitem_func = self._getitem_dummy
+
+        ex_data = conditions_dict[list(conditions_dict.keys())[
+            0]]['input_points'][0]
+        for name, attr in ex_data.items():
+            if isinstance(attr, LabelTensor):
+                self.in_labels[name] = attr.stored_labels
+        ex_data = conditions_dict[list(conditions_dict.keys())[
+            0]]['output_points'][0]
+        if isinstance(ex_data, LabelTensor):
+            self.out_labels = ex_data.labels
+
+        self._create_graph_batch_from_list = self._labelise_batch(
+            self._base_create_graph_batch_from_list) if self.in_labels \
+            else self._base_create_graph_batch_from_list
+        self._create_output_batch = self._labelise_tensor(
+            self._base_create_output_batch) if self.out_labels is not None \
+            else self._base_create_output_batch

     def fetch_from_idx_list(self, idx):
         to_return_dict = {}

@@ -119,17 +162,24 @@ class PinaGraphDataset(PinaDataset):
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
                 cond_idx = [idx % condition_len for idx in cond_idx]
-            to_return_dict[condition] = {k: Batch.from_data_list([
-                v[i] for i in cond_idx])
-                if isinstance(v, list)
-                else v[
-                cond_idx].reshape(
-                -1, *v[cond_idx].shape[2:])
-                for k, v in data.items()
-            }
+            to_return_dict[condition] = {
+                k: self._create_graph_batch_from_list([v[i] for i in idx])
+                if isinstance(v, list)
+                else self._create_output_batch(v[idx])
+                for k, v in data.items()
+            }
         return to_return_dict

-    def _getitem_list(self, idx):
+    def _base_create_graph_batch_from_list(self, data):
+        batch = PinaBatch.from_data_list(data)
+        return batch
+
+    def _base_create_output_batch(self, data):
+        out = data.reshape(-1, *data.shape[2:])
+        return out
+
+    def _getitem_dummy(self, idx):
         return idx

     def _getitem_int(self, idx):
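
The tensor branch of fetch_from_idx_list now goes through _base_create_output_batch, which flattens the leading sample and node dimensions with reshape(-1, *shape[2:]) so output tensors line up with the flat node dimension of a graph batch. The same operation in isolation:

import torch

v = torch.rand(100, 50, 3)     # 100 samples, 50 nodes each, 3 features
idx = [0, 3, 7]
out = v[idx].reshape(-1, *v[idx].shape[2:])
assert out.shape == (150, 3)   # 3 samples * 50 nodes, flattened like a Batch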
@@ -144,3 +194,31 @@ class PinaGraphDataset(PinaDataset):

     def __getitem__(self, idx):
         return self._getitem_func(idx)
+
+    def _labelise_batch(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            batch = func(*args, **kwargs)
+            for k, v in self.in_labels.items():
+                tmp = batch[k]
+                tmp.labels = v
+                batch[k] = tmp
+            return batch
+        return wrapper
+
+    def _labelise_tensor(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            out = func(*args, **kwargs)
+            if isinstance(out, LabelTensor):
+                out.labels = self.out_labels
+            return out
+        return wrapper
+
+    def create_graph_batch(self, data):
+        """
+        # TODO
+        """
+        if isinstance(data[0], Data):
+            return self._create_graph_batch_from_list(data)
+        return self._create_output_batch(data)
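
Batching through torch_geometric drops the label metadata that LabelTensor carries, so the base batching functions are wrapped with the decorators above, which re-attach the stored labels afterwards. A self-contained sketch of the same pattern, with a hypothetical stand-in for LabelTensor:

import functools

class Labeled:
    # Hypothetical stand-in for LabelTensor: a value plus label metadata.
    def __init__(self, value, labels=None):
        self.value = value
        self.labels = labels

def labelise(func, labels):
    # Mirrors _labelise_tensor: restore labels lost during batching.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        out.labels = labels
        return out
    return wrapper

def batch(items):
    # Like Batch.from_data_list, this drops the labels of its inputs.
    return Labeled([item.value for item in items])

batch = labelise(batch, labels=['u', 'v', 'w'])
out = batch([Labeled(1, ['u', 'v', 'w']), Labeled(2, ['u', 'v', 'w'])])
assert out.labels == ['u', 'v', 'w']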

File 4 of 6 — graph module (Graph, RadiusGraph, KNNGraph)

@@ -108,16 +108,14 @@ class Graph:
                 x)

         # Perform the graph construction
-        self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
+        self._build_graph_list(
+            x, pos, edge_index, edge_attr, additional_params)

     def _build_graph_list(self, x, pos, edge_index, edge_attr,
                           additional_params):
         for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
-            if isinstance(x_, LabelTensor):
-                x_ = x_.tensor
             add_params_local = {k: v[i] for k, v in additional_params.items()}
             if edge_attr is not None:
                 self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
                                       edge_attr=edge_attr[i],
                                       **add_params_local))

@@ -127,7 +125,8 @@ class Graph:
     @staticmethod
     def _build_edge_attr(x, pos, edge_index):
-        distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
+        distance = torch.abs(pos[edge_index[0]] -
+                             pos[edge_index[1]]).as_subclass(torch.Tensor)
         return distance

     @staticmethod

@@ -165,7 +164,8 @@ class Graph:
         # If edge_index is a 3D tensor, we split it into a list of 2D tensors
         if edge_index is not None:
             if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
-                edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
+                edge_index = [edge_index[i]
+                              for i in range(edge_index.shape[0])]
             elif not (isinstance(edge_index, list) and all(
                     t.ndim == 2 for t in edge_index)) and not (
                     isinstance(edge_index,

@@ -219,7 +219,7 @@ class Graph:
         if isinstance(edge_attr, list):
             if len(edge_attr) != data_len:
                 raise TypeError("edge_attr must have the same length as x "
-                                    "and pos.")
+                                "and pos.")
             return [edge_attr] * data_len
         if build_edge_attr:

@@ -258,6 +258,8 @@ class RadiusGraph(Graph):
         """
         dist = torch.cdist(points, points, p=2)
         edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
+        if isinstance(edge_index, LabelTensor):
+            edge_index = edge_index.tensor
         return edge_index

@@ -293,4 +295,6 @@ class KNNGraph(Graph):
         row = torch.arange(points.size(0)).repeat_interleave(k)
         col = knn_indices.flatten()
         edge_index = torch.stack([row, col], dim=0)
+        if isinstance(edge_index, LabelTensor):
+            edge_index = edge_index.tensor
         return edge_index
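
The recurring .as_subclass(torch.Tensor) and .tensor downcasts address the underlying bug: arithmetic on a torch.Tensor subclass such as LabelTensor propagates the subclass, while torch_geometric expects plain tensors for edge indices and attributes. A minimal demonstration with a bare subclass:

import torch

class TaggedTensor(torch.Tensor):
    # Hypothetical stand-in for a labelled tensor subclass like LabelTensor.
    pass

pos = torch.rand(10, 2).as_subclass(TaggedTensor)
dist = torch.abs(pos[:5] - pos[5:])        # the subclass propagates
assert isinstance(dist, TaggedTensor)

plain = dist.as_subclass(torch.Tensor)     # downcast to a plain Tensor
assert type(plain) is torch.Tensor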

File 5 of 6 — trainer module (Trainer)

@@ -105,9 +105,9 @@ class Trainer(lightning.pytorch.Trainer):
         # checking compilation and automatic batching
         if compile is None or sys.platform == "win32":
             compile = False
-        if automatic_batching is None:
-            automatic_batching = False
+        self.automatic_batching = automatic_batching if automatic_batching \
+            is not None else False

         # set attributes
         self.compile = compile
         self.solver = solver

@@ -115,7 +115,7 @@ class Trainer(lightning.pytorch.Trainer):
         self._move_to_device()
         self.data_module = None
         self._create_datamodule(train_size, test_size, val_size, predict_size,
-                                 batch_size, automatic_batching, pin_memory,
+                                batch_size, automatic_batching, pin_memory,
                                 num_workers)
         # logging
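
Together with the PinaDataModule change above, automatic_batching is now tri-state: the default moves from False to None so that "argument not given" can be told apart from an explicit False, and the Trainer resolves None to False itself. The resolution idiom in isolation:

def resolve_automatic_batching(automatic_batching=None):
    # None means "not specified" and falls back to False; an explicit
    # True/False is kept as passed.
    return automatic_batching if automatic_batching is not None else False

assert resolve_automatic_batching() is False
assert resolve_automatic_batching(True) is True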

File 6 of 6 — data module tests

@@ -13,10 +13,10 @@ from torch.utils.data import DataLoader

 input_tensor = torch.rand((100, 10))
 output_tensor = torch.rand((100, 2))

-x = torch.rand((100, 50 , 10))
-pos = torch.rand((100, 50 , 2))
+x = torch.rand((100, 50, 10))
+pos = torch.rand((100, 50, 2))
 input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True)
-output_graph = torch.rand((100, 50 , 10))
+output_graph = torch.rand((100, 50, 10))

 @pytest.mark.parametrize(
@@ -30,6 +30,7 @@ def test_constructor(input_, output_):
     problem = SupervisedProblem(input_=input_, output_=output_)
     PinaDataModule(problem)

+
 @pytest.mark.parametrize(
     "input_, output_",
     [
@@ -46,14 +47,15 @@ def test_constructor(input_, output_):
 )
 def test_setup_train(input_, output_, train_size, val_size, test_size):
     problem = SupervisedProblem(input_=input_, output_=output_)
-    dm = PinaDataModule(problem, train_size=train_size, val_size=val_size, test_size=test_size)
+    dm = PinaDataModule(problem, train_size=train_size,
+                        val_size=val_size, test_size=test_size)
     dm.setup()
     assert hasattr(dm, "train_dataset")
     if isinstance(input_, torch.Tensor):
         assert isinstance(dm.train_dataset, PinaTensorDataset)
     else:
         assert isinstance(dm.train_dataset, PinaGraphDataset)
-    #assert len(dm.train_dataset) == int(len(input_) * train_size)
+    # assert len(dm.train_dataset) == int(len(input_) * train_size)
     if test_size > 0:
         assert hasattr(dm, "test_dataset")
         assert dm.test_dataset is None

@@ -64,7 +66,8 @@ def test_setup_train(input_, output_, train_size, val_size, test_size):
         assert isinstance(dm.val_dataset, PinaTensorDataset)
     else:
         assert isinstance(dm.val_dataset, PinaGraphDataset)
-    #assert len(dm.val_dataset) == int(len(input_) * val_size)
+    # assert len(dm.val_dataset) == int(len(input_) * val_size)
+
 @pytest.mark.parametrize(
     "input_, output_",
@@ -82,7 +85,8 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
 )
 def test_setup_test(input_, output_, train_size, val_size, test_size):
     problem = SupervisedProblem(input_=input_, output_=output_)
-    dm = PinaDataModule(problem, train_size=train_size, val_size=val_size, test_size=test_size)
+    dm = PinaDataModule(problem, train_size=train_size,
+                        val_size=val_size, test_size=test_size)
     dm.setup(stage='test')
     if train_size > 0:
         assert hasattr(dm, "train_dataset")

@@ -94,13 +98,14 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
         assert dm.val_dataset is None
     else:
         assert not hasattr(dm, "val_dataset")
+
     assert hasattr(dm, "test_dataset")
     if isinstance(input_, torch.Tensor):
         assert isinstance(dm.test_dataset, PinaTensorDataset)
     else:
         assert isinstance(dm.test_dataset, PinaGraphDataset)
-    #assert len(dm.test_dataset) == int(len(input_) * test_size)
+    # assert len(dm.test_dataset) == int(len(input_) * test_size)

 @pytest.mark.parametrize(
     "input_, output_",
@@ -112,7 +117,8 @@ def test_setup_test(input_, output_, train_size, val_size, test_size):
 def test_dummy_dataloader(input_, output_):
     problem = SupervisedProblem(input_=input_, output_=output_)
     solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
-    trainer = Trainer(solver, batch_size=None, train_size=.7, val_size=.3, test_size=0.)
+    trainer = Trainer(solver, batch_size=None, train_size=.7,
+                      val_size=.3, test_size=0.)
     dm = trainer.data_module
     dm.setup()
     dm.trainer = trainer

@@ -140,6 +146,7 @@ def test_dummy_dataloader(input_, output_):
     assert isinstance(data[0][1]['input_points'], torch.Tensor)
     assert isinstance(data[0][1]['output_points'], torch.Tensor)

+
 @pytest.mark.parametrize(
     "input_, output_",
     [
@@ -147,10 +154,17 @@ def test_dummy_dataloader(input_, output_):
         (input_graph, output_graph)
     ]
 )
-def test_dataloader(input_, output_):
+@pytest.mark.parametrize(
+    "automatic_batching",
+    [
+        True, False
+    ]
+)
+def test_dataloader(input_, output_, automatic_batching):
     problem = SupervisedProblem(input_=input_, output_=output_)
     solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
-    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3, test_size=0.)
+    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
+                      test_size=0., automatic_batching=automatic_batching)
     dm = trainer.data_module
     dm.setup()
     dm.trainer = trainer

@@ -176,3 +190,67 @@ def test_dataloader(input_, output_):
     assert isinstance(data['data']['input_points'], torch.Tensor)
     assert isinstance(data['data']['output_points'], torch.Tensor)
+
+from pina import LabelTensor
+
+input_tensor = LabelTensor(torch.rand((100, 3)), ['u', 'v', 'w'])
+output_tensor = LabelTensor(torch.rand((100, 3)), ['u', 'v', 'w'])
+x = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w'])
+pos = LabelTensor(torch.rand((100, 50, 2)), ['x', 'y'])
+input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True)
+output_graph = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w'])
+
+
+@pytest.mark.parametrize(
+    "input_, output_",
+    [
+        (input_tensor, output_tensor),
+        (input_graph, output_graph)
+    ]
+)
+@pytest.mark.parametrize(
+    "automatic_batching",
+    [
+        True, False
+    ]
+)
+def test_dataloader_labels(input_, output_, automatic_batching):
+    problem = SupervisedProblem(input_=input_, output_=output_)
+    solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10))
+    trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3,
+                      test_size=0., automatic_batching=automatic_batching)
+    dm = trainer.data_module
+    dm.setup()
+    dm.trainer = trainer
+    dataloader = dm.train_dataloader()
+    assert isinstance(dataloader, DataLoader)
+    assert len(dataloader) == 7
+    data = next(iter(dataloader))
+    assert isinstance(data, dict)
+    if isinstance(input_, RadiusGraph):
+        assert isinstance(data['data']['input_points'], Batch)
+        assert isinstance(data['data']['input_points'].x, LabelTensor)
+        assert data['data']['input_points'].x.labels == ['u', 'v', 'w']
+        assert data['data']['input_points'].pos.labels == ['x', 'y']
+    else:
+        assert isinstance(data['data']['input_points'], LabelTensor)
+        assert data['data']['input_points'].labels == ['u', 'v', 'w']
+    assert isinstance(data['data']['output_points'], LabelTensor)
+    assert data['data']['output_points'].labels == ['u', 'v', 'w']
+
+    dataloader = dm.val_dataloader()
+    assert isinstance(dataloader, DataLoader)
+    assert len(dataloader) == 3
+    data = next(iter(dataloader))
+    assert isinstance(data, dict)
+    if isinstance(input_, RadiusGraph):
+        assert isinstance(data['data']['input_points'], Batch)
+        assert isinstance(data['data']['input_points'].x, LabelTensor)
+        assert data['data']['input_points'].x.labels == ['u', 'v', 'w']
+        assert data['data']['input_points'].pos.labels == ['x', 'y']
+    else:
+        assert isinstance(data['data']['input_points'], torch.Tensor)
+        assert isinstance(data['data']['input_points'], LabelTensor)
+        assert data['data']['input_points'].labels == ['u', 'v', 'w']
+    assert isinstance(data['data']['output_points'], torch.Tensor)
+    assert data['data']['output_points'].labels == ['u', 'v', 'w']
+
+
+test_dataloader_labels(input_graph, output_graph, True)