Add Graph support in Dataset and Dataloader
This commit is contained in:
committed by
Nicola Demo
parent
eb146ea2ea
commit
ccc5f5a322
@@ -4,6 +4,7 @@ Basic data module implementation
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..graph import Graph
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
@@ -42,38 +43,43 @@ class BaseDataset(Dataset):
|
||||
collector = problem.collector
|
||||
for slot in self.__slots__:
|
||||
setattr(self, slot, [])
|
||||
|
||||
num_el_per_condition = []
|
||||
idx = 0
|
||||
for name, data in collector.data_collections.items():
|
||||
keys = []
|
||||
for k, v in data.items():
|
||||
if isinstance(v, LabelTensor):
|
||||
keys.append(k)
|
||||
keys = list(data.keys())
|
||||
current_cond_num_el = None
|
||||
if sorted(self.__slots__) == sorted(keys):
|
||||
|
||||
for slot in self.__slots__:
|
||||
slot_data = data[slot]
|
||||
if isinstance(slot_data, (LabelTensor, torch.Tensor,
|
||||
Graph)):
|
||||
if current_cond_num_el is None:
|
||||
current_cond_num_el = len(slot_data)
|
||||
elif current_cond_num_el != len(slot_data):
|
||||
raise ValueError('Different number of conditions')
|
||||
current_list = getattr(self, slot)
|
||||
current_list.append(data[slot])
|
||||
current_list += [data[slot]] if not (
|
||||
isinstance(data[slot], list)) else data[slot]
|
||||
num_el_per_condition.append(current_cond_num_el)
|
||||
self.condition_names[idx] = name
|
||||
idx += 1
|
||||
|
||||
if len(getattr(self, self.__slots__[0])) > 0:
|
||||
input_list = getattr(self, self.__slots__[0])
|
||||
if num_el_per_condition:
|
||||
self.condition_indices = torch.cat(
|
||||
[
|
||||
torch.tensor([i] * len(input_list[i]), dtype=torch.uint8)
|
||||
for i in range(len(self.condition_names))
|
||||
torch.tensor([i] * num_el_per_condition[i],
|
||||
dtype=torch.uint8)
|
||||
for i in range(len(num_el_per_condition))
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
for slot in self.__slots__:
|
||||
current_attribute = getattr(self, slot)
|
||||
setattr(self, slot, LabelTensor.vstack(current_attribute))
|
||||
if all(isinstance(a, LabelTensor) for a in current_attribute):
|
||||
setattr(self, slot, LabelTensor.vstack(current_attribute))
|
||||
else:
|
||||
self.condition_indices = torch.tensor([], dtype=torch.uint8)
|
||||
for slot in self.__slots__:
|
||||
setattr(self, slot, torch.tensor([]))
|
||||
|
||||
self.device = device
|
||||
|
||||
def __len__(self):
|
||||
@@ -89,11 +95,10 @@ class BaseDataset(Dataset):
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, str):
|
||||
return getattr(self, idx).to(self.device)
|
||||
|
||||
if isinstance(idx, slice):
|
||||
to_return_list = []
|
||||
for i in self.__slots__:
|
||||
to_return_list.append(getattr(self, i)[[idx]].to(self.device))
|
||||
to_return_list.append(getattr(self, i)[idx].to(self.device))
|
||||
return to_return_list
|
||||
|
||||
if isinstance(idx, (tuple, list)):
|
||||
|
||||
@@ -6,7 +6,8 @@ from .pina_subset import PinaSubset
|
||||
|
||||
class Batch:
|
||||
"""
|
||||
Implementation of the Batch class used during training to perform SGD optimization.
|
||||
Implementation of the Batch class used during training to perform SGD
|
||||
optimization.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_dict, idx_dict):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
Module for PinaSubset class
|
||||
"""
|
||||
from pina import LabelTensor
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class PinaSubset:
|
||||
@@ -23,4 +25,9 @@ class PinaSubset:
|
||||
return len(self.indices)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self.dataset.__getattribute__(name)
|
||||
tensor = self.dataset.__getattribute__(name)
|
||||
if isinstance(tensor, (LabelTensor, Tensor)):
|
||||
return tensor[self.indices]
|
||||
if isinstance(tensor, list):
|
||||
return [tensor[i] for i in self.indices]
|
||||
raise AttributeError("No attribute named {}".format(name))
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
Sample dataset module
|
||||
"""
|
||||
from .base_dataset import BaseDataset
|
||||
from ..condition.input_equation_condition import InputPointsEquationCondition
|
||||
|
||||
|
||||
class SamplePointDataset(BaseDataset):
|
||||
"""
|
||||
@@ -9,4 +11,4 @@ class SamplePointDataset(BaseDataset):
|
||||
composed of only input points.
|
||||
"""
|
||||
data_type = 'physics'
|
||||
__slots__ = ['input_points']
|
||||
__slots__ = InputPointsEquationCondition.__slots__
|
||||
|
||||
@@ -6,7 +6,8 @@ from .base_dataset import BaseDataset
|
||||
|
||||
class SupervisedDataset(BaseDataset):
|
||||
"""
|
||||
This class extends the BaseDataset to handle datasets that consist of input-output pairs.
|
||||
This class extends the BaseDataset to handle datasets that consist of
|
||||
input-output pairs.
|
||||
"""
|
||||
data_type = 'supervised'
|
||||
__slots__ = ['input_points', 'output_points']
|
||||
|
||||
Reference in New Issue
Block a user