Add Graph support in Dataset and Dataloader

This commit is contained in:
FilippoOlivo
2024-10-23 15:04:28 +02:00
committed by Nicola Demo
parent eb146ea2ea
commit ccc5f5a322
11 changed files with 125 additions and 75 deletions

View File

@@ -94,7 +94,8 @@ class Collector:
self.data_collections[loc] = dict(zip(keys, values))
else:
raise RuntimeError(
'Try to sample variables which are not in problem defined in the problem')
'Try to sample variables which are not in problem defined '
'in the problem')
def add_points(self, new_points_dict):
"""

View File

@@ -4,6 +4,7 @@ Basic data module implementation
from torch.utils.data import Dataset
import torch
from ..label_tensor import LabelTensor
from ..graph import Graph
class BaseDataset(Dataset):
@@ -42,38 +43,43 @@ class BaseDataset(Dataset):
collector = problem.collector
for slot in self.__slots__:
setattr(self, slot, [])
num_el_per_condition = []
idx = 0
for name, data in collector.data_collections.items():
keys = []
for k, v in data.items():
if isinstance(v, LabelTensor):
keys.append(k)
keys = list(data.keys())
current_cond_num_el = None
if sorted(self.__slots__) == sorted(keys):
for slot in self.__slots__:
slot_data = data[slot]
if isinstance(slot_data, (LabelTensor, torch.Tensor,
Graph)):
if current_cond_num_el is None:
current_cond_num_el = len(slot_data)
elif current_cond_num_el != len(slot_data):
raise ValueError('Different number of conditions')
current_list = getattr(self, slot)
current_list.append(data[slot])
current_list += [data[slot]] if not (
isinstance(data[slot], list)) else data[slot]
num_el_per_condition.append(current_cond_num_el)
self.condition_names[idx] = name
idx += 1
if len(getattr(self, self.__slots__[0])) > 0:
input_list = getattr(self, self.__slots__[0])
if num_el_per_condition:
self.condition_indices = torch.cat(
[
torch.tensor([i] * len(input_list[i]), dtype=torch.uint8)
for i in range(len(self.condition_names))
torch.tensor([i] * num_el_per_condition[i],
dtype=torch.uint8)
for i in range(len(num_el_per_condition))
],
dim=0,
)
for slot in self.__slots__:
current_attribute = getattr(self, slot)
if all(isinstance(a, LabelTensor) for a in current_attribute):
setattr(self, slot, LabelTensor.vstack(current_attribute))
else:
self.condition_indices = torch.tensor([], dtype=torch.uint8)
for slot in self.__slots__:
setattr(self, slot, torch.tensor([]))
self.device = device
def __len__(self):
@@ -89,11 +95,10 @@ class BaseDataset(Dataset):
def __getitem__(self, idx):
if isinstance(idx, str):
return getattr(self, idx).to(self.device)
if isinstance(idx, slice):
to_return_list = []
for i in self.__slots__:
to_return_list.append(getattr(self, i)[[idx]].to(self.device))
to_return_list.append(getattr(self, i)[idx].to(self.device))
return to_return_list
if isinstance(idx, (tuple, list)):

View File

@@ -6,7 +6,8 @@ from .pina_subset import PinaSubset
class Batch:
"""
Implementation of the Batch class used during training to perform SGD optimization.
Implementation of the Batch class used during training to perform SGD
optimization.
"""
def __init__(self, dataset_dict, idx_dict):

View File

@@ -1,6 +1,8 @@
"""
Module for PinaSubset class
"""
from pina import LabelTensor
from torch import Tensor
class PinaSubset:
@@ -23,4 +25,9 @@ class PinaSubset:
return len(self.indices)
def __getattr__(self, name):
return self.dataset.__getattribute__(name)
tensor = self.dataset.__getattribute__(name)
if isinstance(tensor, (LabelTensor, Tensor)):
return tensor[self.indices]
if isinstance(tensor, list):
return [tensor[i] for i in self.indices]
raise AttributeError("No attribute named {}".format(name))

View File

@@ -2,6 +2,8 @@
Sample dataset module
"""
from .base_dataset import BaseDataset
from ..condition.input_equation_condition import InputPointsEquationCondition
class SamplePointDataset(BaseDataset):
"""
@@ -9,4 +11,4 @@ class SamplePointDataset(BaseDataset):
composed of only input points.
"""
data_type = 'physics'
__slots__ = ['input_points']
__slots__ = InputPointsEquationCondition.__slots__

View File

@@ -6,7 +6,8 @@ from .base_dataset import BaseDataset
class SupervisedDataset(BaseDataset):
"""
This class extends the BaseDataset to handle datasets that consist of input-output pairs.
This class extends the BaseDataset to handle datasets that consist of
input-output pairs.
"""
data_type = 'supervised'
__slots__ = ['input_points', 'output_points']

View File

@@ -413,7 +413,6 @@ class LabelTensor(torch.Tensor):
return selected_lt
def _getitem_permutation(self, index, selected_lt):
new_labels = deepcopy(self.full_labels)
new_labels.update(self._update_label_for_dim(self.full_labels, index,
0))
@@ -429,6 +428,8 @@ class LabelTensor(torch.Tensor):
:param dim:
:return:
"""
if isinstance(index, torch.Tensor):
index = index.nonzero()
if isinstance(index, list):
return {dim: {'dof': [old_labels[dim]['dof'][i] for i in index],
'name': old_labels[dim]['name']}}
@@ -436,7 +437,6 @@ class LabelTensor(torch.Tensor):
return {dim: {'dof': old_labels[dim]['dof'][index],
'name': old_labels[dim]['name']}}
def sort_labels(self, dim=None):
def argsort(lst):
return sorted(range(len(lst)), key=lambda x: lst[x])

View File

@@ -38,7 +38,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
check_consistency(problem, AbstractProblem)
self._check_solver_consistency(problem)
#Check consistency of models argument and encapsulate in list
# Check consistency of models argument and encapsulate in list
if not isinstance(models, list):
check_consistency(models, torch.nn.Module)
# put everything in a list if only one input
@@ -49,17 +49,17 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
check_consistency(models[idx], torch.nn.Module)
len_model = len(models)
#If use_lt is true add extract operation in input
# If use_lt is true add extract operation in input
if use_lt is True:
for idx in range(len(models)):
for idx, model in enumerate(models):
models[idx] = Network(
model=models[idx],
model=model,
input_variables=problem.input_variables,
output_variables=problem.output_variables,
extra_features=extra_features,
)
#Check scheduler consistency + encapsulation
# Check scheduler consistency + encapsulation
if not isinstance(schedulers, list):
check_consistency(schedulers, Scheduler)
schedulers = [schedulers]
@@ -67,7 +67,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
for scheduler in schedulers:
check_consistency(scheduler, Scheduler)
#Check optimizer consistency + encapsulation
# Check optimizer consistency + encapsulation
if not isinstance(optimizers, list):
check_consistency(optimizers, Optimizer)
optimizers = [optimizers]
@@ -141,5 +141,6 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
if not set(self.accepted_condition_types).issubset(
condition.condition_type):
raise ValueError(
f'{self.__name__} support only dose not support condition {condition.condition_type}'
f'{self.__name__} support only dose not support condition '
f'{condition.condition_type}'
)

View File

@@ -130,14 +130,13 @@ class SupervisedSolver(SolverInterface):
if not hasattr(condition, "output_points"):
raise NotImplementedError(
f"{type(self).__name__} works only in data-driven mode.")
output_pts = out[condition_idx == condition_id]
input_pts = pts[condition_idx == condition_id]
input_pts.labels = pts.labels
output_pts.labels = out.labels
loss = (self.loss_data(input_pts=input_pts, output_pts=output_pts))
loss = self.loss_data(input_pts=input_pts, output_pts=output_pts)
loss = loss.as_subclass(torch.Tensor)
self.log("mean_loss", float(loss), prog_bar=True, logger=True)

View File

@@ -60,9 +60,12 @@ class Trainer(pytorch_lightning.Trainer):
if not self.solver.problem.collector.full:
error_message = '\n'.join(
[
f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
"not sampled"}"""
for key, value in
self.solver.problem.collector._is_conditions_ready.items()])
self._solver.problem.collector._is_conditions_ready.items()
]
)
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')

View File

@@ -1,13 +1,15 @@
import math
import torch
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, UnsupervisedDataset, unsupervised_dataset
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, \
UnsupervisedDataset
from pina.data import PinaDataLoader
from pina import LabelTensor, Condition
from pina.equation import Equation
from pina.domain import CartesianDomain
from pina.problem import SpatialProblem
from pina.problem import SpatialProblem, AbstractProblem
from pina.operators import laplacian
from pina.equation.equation_factory import FixedValue
from pina.graph import Graph
def laplace_equation(input_, output_):
@@ -98,8 +100,8 @@ def test_data():
assert dataset.input_points.shape == (61, 2)
assert dataset['input_points'].labels == ['x', 'y']
assert dataset.input_points.labels == ['x', 'y']
assert dataset['input_points', 3:].shape == (58, 2)
assert dataset[3:][1].labels == ['u']
assert dataset.input_points[3:].shape == (58, 2)
assert dataset.output_points[:3].labels == ['u']
assert dataset.output_points.shape == (61, 1)
assert dataset.output_points.labels == ['u']
assert dataset.condition_indices.dtype == torch.uint8
@@ -193,4 +195,32 @@ def test_loader():
assert i.unsupervised.input_points.requires_grad == True
test_loader()
# Random fixtures for the graph loader test. GraphProblem slices both
# tensors as ``[i, :, :]`` for i in range(100), so each graph receives a
# (100, 2) block labelled 'x', 'y' as ``nodes_coordinates`` ...
coordinates = LabelTensor(torch.rand((100, 100, 2)), labels=['x', 'y'])
# ... and the matching (100, 3) block labelled 'ux', 'uy', 'p' as
# ``nodes_data``.
data = LabelTensor(torch.rand((100, 100, 3)), labels=['ux', 'uy', 'p'])
class GraphProblem(AbstractProblem):
    """
    Toy problem whose input points are a list of graphs.

    One graph is built with ``Graph.build('radius', ..., radius=0.2)`` for
    each of the 100 leading indices of the module-level ``coordinates``
    and ``data`` tensors. The list is paired with a random ``(100, 3)``
    output tensor in a single condition named ``graph_data``.
    """
    output = LabelTensor(torch.rand((100, 3)), labels=['ux', 'uy', 'p'])
    input = [
        Graph.build(
            'radius',
            nodes_coordinates=coordinates[sample, :, :],
            nodes_data=data[sample, :, :],
            radius=0.2,
        )
        for sample in range(100)
    ]
    # NOTE(review): the output tensor is labelled 'ux', 'uy', 'p' while the
    # declared output variable is 'u' - confirm this mismatch is intended.
    output_variables = ['u']
    conditions = dict(
        graph_data=Condition(input_points=input, output_points=output),
    )
graph_problem = GraphProblem()
def test_loader_graph():
    """
    Check that the train dataloader built from ``graph_problem`` yields
    batches of at most 10 elements whose supervised input points are a
    plain list of ``Graph`` objects.
    """
    module = PinaDataModule(graph_problem, device='cpu', batch_size=10)
    module.setup()
    for batch in module.train_dataloader():
        assert len(batch) <= 10
        graphs = batch.supervised.input_points
        assert isinstance(graphs, list)
        assert all(isinstance(graph, Graph) for graph in graphs)