Add Graph support in Dataset and Dataloader
This commit is contained in:
committed by
Nicola Demo
parent
eb146ea2ea
commit
ccc5f5a322
@@ -94,7 +94,8 @@ class Collector:
|
|||||||
self.data_collections[loc] = dict(zip(keys, values))
|
self.data_collections[loc] = dict(zip(keys, values))
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'Try to sample variables which are not in problem defined in the problem')
|
'Try to sample variables which are not in problem defined '
|
||||||
|
'in the problem')
|
||||||
|
|
||||||
def add_points(self, new_points_dict):
|
def add_points(self, new_points_dict):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ Basic data module implementation
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
import torch
|
import torch
|
||||||
from ..label_tensor import LabelTensor
|
from ..label_tensor import LabelTensor
|
||||||
|
from ..graph import Graph
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset):
|
class BaseDataset(Dataset):
|
||||||
@@ -42,38 +43,43 @@ class BaseDataset(Dataset):
|
|||||||
collector = problem.collector
|
collector = problem.collector
|
||||||
for slot in self.__slots__:
|
for slot in self.__slots__:
|
||||||
setattr(self, slot, [])
|
setattr(self, slot, [])
|
||||||
|
num_el_per_condition = []
|
||||||
idx = 0
|
idx = 0
|
||||||
for name, data in collector.data_collections.items():
|
for name, data in collector.data_collections.items():
|
||||||
keys = []
|
keys = list(data.keys())
|
||||||
for k, v in data.items():
|
current_cond_num_el = None
|
||||||
if isinstance(v, LabelTensor):
|
|
||||||
keys.append(k)
|
|
||||||
if sorted(self.__slots__) == sorted(keys):
|
if sorted(self.__slots__) == sorted(keys):
|
||||||
|
|
||||||
for slot in self.__slots__:
|
for slot in self.__slots__:
|
||||||
|
slot_data = data[slot]
|
||||||
|
if isinstance(slot_data, (LabelTensor, torch.Tensor,
|
||||||
|
Graph)):
|
||||||
|
if current_cond_num_el is None:
|
||||||
|
current_cond_num_el = len(slot_data)
|
||||||
|
elif current_cond_num_el != len(slot_data):
|
||||||
|
raise ValueError('Different number of conditions')
|
||||||
current_list = getattr(self, slot)
|
current_list = getattr(self, slot)
|
||||||
current_list.append(data[slot])
|
current_list += [data[slot]] if not (
|
||||||
|
isinstance(data[slot], list)) else data[slot]
|
||||||
|
num_el_per_condition.append(current_cond_num_el)
|
||||||
self.condition_names[idx] = name
|
self.condition_names[idx] = name
|
||||||
idx += 1
|
idx += 1
|
||||||
|
if num_el_per_condition:
|
||||||
if len(getattr(self, self.__slots__[0])) > 0:
|
|
||||||
input_list = getattr(self, self.__slots__[0])
|
|
||||||
self.condition_indices = torch.cat(
|
self.condition_indices = torch.cat(
|
||||||
[
|
[
|
||||||
torch.tensor([i] * len(input_list[i]), dtype=torch.uint8)
|
torch.tensor([i] * num_el_per_condition[i],
|
||||||
for i in range(len(self.condition_names))
|
dtype=torch.uint8)
|
||||||
|
for i in range(len(num_el_per_condition))
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
for slot in self.__slots__:
|
for slot in self.__slots__:
|
||||||
current_attribute = getattr(self, slot)
|
current_attribute = getattr(self, slot)
|
||||||
|
if all(isinstance(a, LabelTensor) for a in current_attribute):
|
||||||
setattr(self, slot, LabelTensor.vstack(current_attribute))
|
setattr(self, slot, LabelTensor.vstack(current_attribute))
|
||||||
else:
|
else:
|
||||||
self.condition_indices = torch.tensor([], dtype=torch.uint8)
|
self.condition_indices = torch.tensor([], dtype=torch.uint8)
|
||||||
for slot in self.__slots__:
|
for slot in self.__slots__:
|
||||||
setattr(self, slot, torch.tensor([]))
|
setattr(self, slot, torch.tensor([]))
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@@ -89,11 +95,10 @@ class BaseDataset(Dataset):
|
|||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
if isinstance(idx, str):
|
if isinstance(idx, str):
|
||||||
return getattr(self, idx).to(self.device)
|
return getattr(self, idx).to(self.device)
|
||||||
|
|
||||||
if isinstance(idx, slice):
|
if isinstance(idx, slice):
|
||||||
to_return_list = []
|
to_return_list = []
|
||||||
for i in self.__slots__:
|
for i in self.__slots__:
|
||||||
to_return_list.append(getattr(self, i)[[idx]].to(self.device))
|
to_return_list.append(getattr(self, i)[idx].to(self.device))
|
||||||
return to_return_list
|
return to_return_list
|
||||||
|
|
||||||
if isinstance(idx, (tuple, list)):
|
if isinstance(idx, (tuple, list)):
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ from .pina_subset import PinaSubset
|
|||||||
|
|
||||||
class Batch:
|
class Batch:
|
||||||
"""
|
"""
|
||||||
Implementation of the Batch class used during training to perform SGD optimization.
|
Implementation of the Batch class used during training to perform SGD
|
||||||
|
optimization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset_dict, idx_dict):
|
def __init__(self, dataset_dict, idx_dict):
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Module for PinaSubset class
|
Module for PinaSubset class
|
||||||
"""
|
"""
|
||||||
|
from pina import LabelTensor
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
class PinaSubset:
|
class PinaSubset:
|
||||||
@@ -23,4 +25,9 @@ class PinaSubset:
|
|||||||
return len(self.indices)
|
return len(self.indices)
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return self.dataset.__getattribute__(name)
|
tensor = self.dataset.__getattribute__(name)
|
||||||
|
if isinstance(tensor, (LabelTensor, Tensor)):
|
||||||
|
return tensor[self.indices]
|
||||||
|
if isinstance(tensor, list):
|
||||||
|
return [tensor[i] for i in self.indices]
|
||||||
|
raise AttributeError("No attribute named {}".format(name))
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
Sample dataset module
|
Sample dataset module
|
||||||
"""
|
"""
|
||||||
from .base_dataset import BaseDataset
|
from .base_dataset import BaseDataset
|
||||||
|
from ..condition.input_equation_condition import InputPointsEquationCondition
|
||||||
|
|
||||||
|
|
||||||
class SamplePointDataset(BaseDataset):
|
class SamplePointDataset(BaseDataset):
|
||||||
"""
|
"""
|
||||||
@@ -9,4 +11,4 @@ class SamplePointDataset(BaseDataset):
|
|||||||
composed of only input points.
|
composed of only input points.
|
||||||
"""
|
"""
|
||||||
data_type = 'physics'
|
data_type = 'physics'
|
||||||
__slots__ = ['input_points']
|
__slots__ = InputPointsEquationCondition.__slots__
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ from .base_dataset import BaseDataset
|
|||||||
|
|
||||||
class SupervisedDataset(BaseDataset):
|
class SupervisedDataset(BaseDataset):
|
||||||
"""
|
"""
|
||||||
This class extends the BaseDataset to handle datasets that consist of input-output pairs.
|
This class extends the BaseDataset to handle datasets that consist of
|
||||||
|
input-output pairs.
|
||||||
"""
|
"""
|
||||||
data_type = 'supervised'
|
data_type = 'supervised'
|
||||||
__slots__ = ['input_points', 'output_points']
|
__slots__ = ['input_points', 'output_points']
|
||||||
|
|||||||
@@ -413,7 +413,6 @@ class LabelTensor(torch.Tensor):
|
|||||||
return selected_lt
|
return selected_lt
|
||||||
|
|
||||||
def _getitem_permutation(self, index, selected_lt):
|
def _getitem_permutation(self, index, selected_lt):
|
||||||
|
|
||||||
new_labels = deepcopy(self.full_labels)
|
new_labels = deepcopy(self.full_labels)
|
||||||
new_labels.update(self._update_label_for_dim(self.full_labels, index,
|
new_labels.update(self._update_label_for_dim(self.full_labels, index,
|
||||||
0))
|
0))
|
||||||
@@ -429,6 +428,8 @@ class LabelTensor(torch.Tensor):
|
|||||||
:param dim:
|
:param dim:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
if isinstance(index, torch.Tensor):
|
||||||
|
index = index.nonzero()
|
||||||
if isinstance(index, list):
|
if isinstance(index, list):
|
||||||
return {dim: {'dof': [old_labels[dim]['dof'][i] for i in index],
|
return {dim: {'dof': [old_labels[dim]['dof'][i] for i in index],
|
||||||
'name': old_labels[dim]['name']}}
|
'name': old_labels[dim]['name']}}
|
||||||
@@ -436,7 +437,6 @@ class LabelTensor(torch.Tensor):
|
|||||||
return {dim: {'dof': old_labels[dim]['dof'][index],
|
return {dim: {'dof': old_labels[dim]['dof'][index],
|
||||||
'name': old_labels[dim]['name']}}
|
'name': old_labels[dim]['name']}}
|
||||||
|
|
||||||
|
|
||||||
def sort_labels(self, dim=None):
|
def sort_labels(self, dim=None):
|
||||||
def argsort(lst):
|
def argsort(lst):
|
||||||
return sorted(range(len(lst)), key=lambda x: lst[x])
|
return sorted(range(len(lst)), key=lambda x: lst[x])
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
|||||||
check_consistency(problem, AbstractProblem)
|
check_consistency(problem, AbstractProblem)
|
||||||
self._check_solver_consistency(problem)
|
self._check_solver_consistency(problem)
|
||||||
|
|
||||||
#Check consistency of models argument and encapsulate in list
|
# Check consistency of models argument and encapsulate in list
|
||||||
if not isinstance(models, list):
|
if not isinstance(models, list):
|
||||||
check_consistency(models, torch.nn.Module)
|
check_consistency(models, torch.nn.Module)
|
||||||
# put everything in a list if only one input
|
# put everything in a list if only one input
|
||||||
@@ -49,17 +49,17 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
|||||||
check_consistency(models[idx], torch.nn.Module)
|
check_consistency(models[idx], torch.nn.Module)
|
||||||
len_model = len(models)
|
len_model = len(models)
|
||||||
|
|
||||||
#If use_lt is true add extract operation in input
|
# If use_lt is true add extract operation in input
|
||||||
if use_lt is True:
|
if use_lt is True:
|
||||||
for idx in range(len(models)):
|
for idx, model in enumerate(models):
|
||||||
models[idx] = Network(
|
models[idx] = Network(
|
||||||
model=models[idx],
|
model=model,
|
||||||
input_variables=problem.input_variables,
|
input_variables=problem.input_variables,
|
||||||
output_variables=problem.output_variables,
|
output_variables=problem.output_variables,
|
||||||
extra_features=extra_features,
|
extra_features=extra_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
#Check scheduler consistency + encapsulation
|
# Check scheduler consistency + encapsulation
|
||||||
if not isinstance(schedulers, list):
|
if not isinstance(schedulers, list):
|
||||||
check_consistency(schedulers, Scheduler)
|
check_consistency(schedulers, Scheduler)
|
||||||
schedulers = [schedulers]
|
schedulers = [schedulers]
|
||||||
@@ -67,7 +67,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
|||||||
for scheduler in schedulers:
|
for scheduler in schedulers:
|
||||||
check_consistency(scheduler, Scheduler)
|
check_consistency(scheduler, Scheduler)
|
||||||
|
|
||||||
#Check optimizer consistency + encapsulation
|
# Check optimizer consistency + encapsulation
|
||||||
if not isinstance(optimizers, list):
|
if not isinstance(optimizers, list):
|
||||||
check_consistency(optimizers, Optimizer)
|
check_consistency(optimizers, Optimizer)
|
||||||
optimizers = [optimizers]
|
optimizers = [optimizers]
|
||||||
@@ -141,5 +141,6 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
|||||||
if not set(self.accepted_condition_types).issubset(
|
if not set(self.accepted_condition_types).issubset(
|
||||||
condition.condition_type):
|
condition.condition_type):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'{self.__name__} support only dose not support condition {condition.condition_type}'
|
f'{self.__name__} support only dose not support condition '
|
||||||
|
f'{condition.condition_type}'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -130,14 +130,13 @@ class SupervisedSolver(SolverInterface):
|
|||||||
if not hasattr(condition, "output_points"):
|
if not hasattr(condition, "output_points"):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{type(self).__name__} works only in data-driven mode.")
|
f"{type(self).__name__} works only in data-driven mode.")
|
||||||
|
|
||||||
output_pts = out[condition_idx == condition_id]
|
output_pts = out[condition_idx == condition_id]
|
||||||
input_pts = pts[condition_idx == condition_id]
|
input_pts = pts[condition_idx == condition_id]
|
||||||
|
|
||||||
input_pts.labels = pts.labels
|
input_pts.labels = pts.labels
|
||||||
output_pts.labels = out.labels
|
output_pts.labels = out.labels
|
||||||
|
|
||||||
loss = (self.loss_data(input_pts=input_pts, output_pts=output_pts))
|
loss = self.loss_data(input_pts=input_pts, output_pts=output_pts)
|
||||||
loss = loss.as_subclass(torch.Tensor)
|
loss = loss.as_subclass(torch.Tensor)
|
||||||
|
|
||||||
self.log("mean_loss", float(loss), prog_bar=True, logger=True)
|
self.log("mean_loss", float(loss), prog_bar=True, logger=True)
|
||||||
|
|||||||
@@ -60,9 +60,12 @@ class Trainer(pytorch_lightning.Trainer):
|
|||||||
if not self.solver.problem.collector.full:
|
if not self.solver.problem.collector.full:
|
||||||
error_message = '\n'.join(
|
error_message = '\n'.join(
|
||||||
[
|
[
|
||||||
f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
|
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
|
||||||
|
"not sampled"}"""
|
||||||
for key, value in
|
for key, value in
|
||||||
self.solver.problem.collector._is_conditions_ready.items()])
|
self._solver.problem.collector._is_conditions_ready.items()
|
||||||
|
]
|
||||||
|
)
|
||||||
raise RuntimeError('Cannot create Trainer if not all conditions '
|
raise RuntimeError('Cannot create Trainer if not all conditions '
|
||||||
'are sampled. The Trainer got the following:\n'
|
'are sampled. The Trainer got the following:\n'
|
||||||
f'{error_message}')
|
f'{error_message}')
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, UnsupervisedDataset, unsupervised_dataset
|
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, \
|
||||||
|
UnsupervisedDataset
|
||||||
from pina.data import PinaDataLoader
|
from pina.data import PinaDataLoader
|
||||||
from pina import LabelTensor, Condition
|
from pina import LabelTensor, Condition
|
||||||
from pina.equation import Equation
|
from pina.equation import Equation
|
||||||
from pina.domain import CartesianDomain
|
from pina.domain import CartesianDomain
|
||||||
from pina.problem import SpatialProblem
|
from pina.problem import SpatialProblem, AbstractProblem
|
||||||
from pina.operators import laplacian
|
from pina.operators import laplacian
|
||||||
from pina.equation.equation_factory import FixedValue
|
from pina.equation.equation_factory import FixedValue
|
||||||
|
from pina.graph import Graph
|
||||||
|
|
||||||
|
|
||||||
def laplace_equation(input_, output_):
|
def laplace_equation(input_, output_):
|
||||||
@@ -98,8 +100,8 @@ def test_data():
|
|||||||
assert dataset.input_points.shape == (61, 2)
|
assert dataset.input_points.shape == (61, 2)
|
||||||
assert dataset['input_points'].labels == ['x', 'y']
|
assert dataset['input_points'].labels == ['x', 'y']
|
||||||
assert dataset.input_points.labels == ['x', 'y']
|
assert dataset.input_points.labels == ['x', 'y']
|
||||||
assert dataset['input_points', 3:].shape == (58, 2)
|
assert dataset.input_points[3:].shape == (58, 2)
|
||||||
assert dataset[3:][1].labels == ['u']
|
assert dataset.output_points[:3].labels == ['u']
|
||||||
assert dataset.output_points.shape == (61, 1)
|
assert dataset.output_points.shape == (61, 1)
|
||||||
assert dataset.output_points.labels == ['u']
|
assert dataset.output_points.labels == ['u']
|
||||||
assert dataset.condition_indices.dtype == torch.uint8
|
assert dataset.condition_indices.dtype == torch.uint8
|
||||||
@@ -193,4 +195,32 @@ def test_loader():
|
|||||||
assert i.unsupervised.input_points.requires_grad == True
|
assert i.unsupervised.input_points.requires_grad == True
|
||||||
|
|
||||||
|
|
||||||
test_loader()
|
coordinates = LabelTensor(torch.rand((100, 100, 2)), labels=['x', 'y'])
|
||||||
|
data = LabelTensor(torch.rand((100, 100, 3)), labels=['ux', 'uy', 'p'])
|
||||||
|
|
||||||
|
|
||||||
|
class GraphProblem(AbstractProblem):
|
||||||
|
output = LabelTensor(torch.rand((100, 3)), labels=['ux', 'uy', 'p'])
|
||||||
|
input = [Graph.build('radius',
|
||||||
|
nodes_coordinates=coordinates[i, :, :],
|
||||||
|
nodes_data=data[i, :, :], radius=0.2)
|
||||||
|
for i in
|
||||||
|
range(100)]
|
||||||
|
output_variables = ['u']
|
||||||
|
|
||||||
|
conditions = {
|
||||||
|
'graph_data': Condition(input_points=input, output_points=output)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
graph_problem = GraphProblem()
|
||||||
|
|
||||||
|
|
||||||
|
def test_loader_graph():
|
||||||
|
data_module = PinaDataModule(graph_problem, device='cpu', batch_size=10)
|
||||||
|
data_module.setup()
|
||||||
|
loader = data_module.train_dataloader()
|
||||||
|
for i in loader:
|
||||||
|
assert len(i) <= 10
|
||||||
|
assert isinstance(i.supervised.input_points, list)
|
||||||
|
assert all(isinstance(x, Graph) for x in i.supervised.input_points)
|
||||||
|
|||||||
Reference in New Issue
Block a user