From ccc5f5a3227b827ef27cd74f6952fe11cdd0b803 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 23 Oct 2024 15:04:28 +0200 Subject: [PATCH] Add Graph support in Dataset and Dataloader --- pina/collector.py | 5 +- pina/data/base_dataset.py | 37 ++++++----- pina/data/pina_batch.py | 3 +- pina/data/pina_subset.py | 9 ++- pina/data/sample_dataset.py | 4 +- pina/data/supervised_dataset.py | 3 +- pina/label_tensor.py | 4 +- pina/solvers/solver.py | 15 +++-- pina/solvers/supervised.py | 3 +- pina/trainer.py | 7 +- tests/test_dataset.py | 110 ++++++++++++++++++++------------ 11 files changed, 125 insertions(+), 75 deletions(-) diff --git a/pina/collector.py b/pina/collector.py index 4ebf236..c48c674 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -49,7 +49,7 @@ class Collector: # if the condition is not ready and domain is not attribute # of condition, we get and store the data if (not self._is_conditions_ready[condition_name]) and ( - not hasattr(condition, "domain")): + not hasattr(condition, "domain")): # get data keys = condition.__slots__ values = [getattr(condition, name) for name in keys] @@ -94,7 +94,8 @@ class Collector: self.data_collections[loc] = dict(zip(keys, values)) else: raise RuntimeError( - 'Try to sample variables which are not in problem defined in the problem') + 'Try to sample variables which are not in problem defined ' + 'in the problem') def add_points(self, new_points_dict): """ diff --git a/pina/data/base_dataset.py b/pina/data/base_dataset.py index b15a0be..d859aac 100644 --- a/pina/data/base_dataset.py +++ b/pina/data/base_dataset.py @@ -4,6 +4,7 @@ Basic data module implementation from torch.utils.data import Dataset import torch from ..label_tensor import LabelTensor +from ..graph import Graph class BaseDataset(Dataset): @@ -42,38 +43,43 @@ class BaseDataset(Dataset): collector = problem.collector for slot in self.__slots__: setattr(self, slot, []) - + num_el_per_condition = [] idx = 0 for name, data in collector.data_collections.items(): - keys = [] - for k, v in data.items(): - if isinstance(v, LabelTensor): - keys.append(k) + keys = list(data.keys()) + current_cond_num_el = None if sorted(self.__slots__) == sorted(keys): - for slot in self.__slots__: + slot_data = data[slot] + if isinstance(slot_data, (LabelTensor, torch.Tensor, + Graph)): + if current_cond_num_el is None: + current_cond_num_el = len(slot_data) + elif current_cond_num_el != len(slot_data): + raise ValueError('Different number of conditions') current_list = getattr(self, slot) - current_list.append(data[slot]) + current_list += [data[slot]] if not ( + isinstance(data[slot], list)) else data[slot] + num_el_per_condition.append(current_cond_num_el) self.condition_names[idx] = name idx += 1 - - if len(getattr(self, self.__slots__[0])) > 0: - input_list = getattr(self, self.__slots__[0]) + if num_el_per_condition: self.condition_indices = torch.cat( [ - torch.tensor([i] * len(input_list[i]), dtype=torch.uint8) - for i in range(len(self.condition_names)) + torch.tensor([i] * num_el_per_condition[i], + dtype=torch.uint8) + for i in range(len(num_el_per_condition)) ], dim=0, ) for slot in self.__slots__: current_attribute = getattr(self, slot) - setattr(self, slot, LabelTensor.vstack(current_attribute)) + if all(isinstance(a, LabelTensor) for a in current_attribute): + setattr(self, slot, LabelTensor.vstack(current_attribute)) else: self.condition_indices = torch.tensor([], dtype=torch.uint8) for slot in self.__slots__: setattr(self, slot, torch.tensor([])) - self.device = device def __len__(self): @@ -89,11 +95,10 @@ class BaseDataset(Dataset): def __getitem__(self, idx): if isinstance(idx, str): return getattr(self, idx).to(self.device) - if isinstance(idx, slice): to_return_list = [] for i in self.__slots__: - to_return_list.append(getattr(self, i)[[idx]].to(self.device)) + to_return_list.append(getattr(self, i)[idx].to(self.device)) return to_return_list if isinstance(idx, (tuple, list)): diff --git a/pina/data/pina_batch.py b/pina/data/pina_batch.py index ed34a91..65b5ac5 100644 --- a/pina/data/pina_batch.py +++ b/pina/data/pina_batch.py @@ -6,7 +6,8 @@ from .pina_subset import PinaSubset class Batch: """ - Implementation of the Batch class used during training to perform SGD optimization. + Implementation of the Batch class used during training to perform SGD + optimization. """ def __init__(self, dataset_dict, idx_dict): diff --git a/pina/data/pina_subset.py b/pina/data/pina_subset.py index 844321b..f1347b6 100644 --- a/pina/data/pina_subset.py +++ b/pina/data/pina_subset.py @@ -1,6 +1,8 @@ """ Module for PinaSubset class """ +from pina import LabelTensor +from torch import Tensor class PinaSubset: @@ -23,4 +25,9 @@ class PinaSubset: return len(self.indices) def __getattr__(self, name): - return self.dataset.__getattribute__(name) + tensor = self.dataset.__getattribute__(name) + if isinstance(tensor, (LabelTensor, Tensor)): + return tensor[self.indices] + if isinstance(tensor, list): + return [tensor[i] for i in self.indices] + raise AttributeError("No attribute named {}".format(name)) diff --git a/pina/data/sample_dataset.py b/pina/data/sample_dataset.py index ba8bd19..99811ca 100644 --- a/pina/data/sample_dataset.py +++ b/pina/data/sample_dataset.py @@ -2,6 +2,8 @@ Sample dataset module """ from .base_dataset import BaseDataset +from ..condition.input_equation_condition import InputPointsEquationCondition + class SamplePointDataset(BaseDataset): """ @@ -9,4 +11,4 @@ class SamplePointDataset(BaseDataset): composed of only input points. """ data_type = 'physics' - __slots__ = ['input_points'] + __slots__ = InputPointsEquationCondition.__slots__ diff --git a/pina/data/supervised_dataset.py b/pina/data/supervised_dataset.py index 2403e3d..be60105 100644 --- a/pina/data/supervised_dataset.py +++ b/pina/data/supervised_dataset.py @@ -6,7 +6,8 @@ from .base_dataset import BaseDataset class SupervisedDataset(BaseDataset): """ - This class extends the BaseDataset to handle datasets that consist of input-output pairs. + This class extends the BaseDataset to handle datasets that consist of + input-output pairs. """ data_type = 'supervised' __slots__ = ['input_points', 'output_points'] diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 62d8795..87def2f 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -413,7 +413,6 @@ class LabelTensor(torch.Tensor): return selected_lt def _getitem_permutation(self, index, selected_lt): - new_labels = deepcopy(self.full_labels) new_labels.update(self._update_label_for_dim(self.full_labels, index, 0)) @@ -429,6 +428,8 @@ class LabelTensor(torch.Tensor): :param dim: :return: """ + if isinstance(index, torch.Tensor): + index = index.nonzero() if isinstance(index, list): return {dim: {'dof': [old_labels[dim]['dof'][i] for i in index], 'name': old_labels[dim]['name']}} @@ -436,7 +437,6 @@ class LabelTensor(torch.Tensor): return {dim: {'dof': old_labels[dim]['dof'][index], 'name': old_labels[dim]['name']}} - def sort_labels(self, dim=None): def argsort(lst): return sorted(range(len(lst)), key=lambda x: lst[x]) diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index 1c6aa2b..6f55ded 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -38,7 +38,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta): check_consistency(problem, AbstractProblem) self._check_solver_consistency(problem) - #Check consistency of models argument and encapsulate in list + # Check consistency of models argument and encapsulate in list if not isinstance(models, list): check_consistency(models, torch.nn.Module) # put everything in a list if only one input @@ -49,17 +49,17 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta): check_consistency(models[idx], torch.nn.Module) len_model = len(models) - #If use_lt is true add extract operation in input + # If use_lt is true add extract operation in input if use_lt is True: - for idx in range(len(models)): + for idx, model in enumerate(models): models[idx] = Network( - model=models[idx], + model=model, input_variables=problem.input_variables, output_variables=problem.output_variables, extra_features=extra_features, ) - #Check scheduler consistency + encapsulation + # Check scheduler consistency + encapsulation if not isinstance(schedulers, list): check_consistency(schedulers, Scheduler) schedulers = [schedulers] @@ -67,7 +67,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta): for scheduler in schedulers: check_consistency(scheduler, Scheduler) - #Check optimizer consistency + encapsulation + # Check optimizer consistency + encapsulation if not isinstance(optimizers, list): check_consistency(optimizers, Optimizer) optimizers = [optimizers] @@ -141,5 +141,6 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta): if not set(self.accepted_condition_types).issubset( condition.condition_type): raise ValueError( - f'{self.__name__} support only dose not support condition {condition.condition_type}' + f'{self.__name__} support only dose not support condition ' + f'{condition.condition_type}' ) diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index a0b0f83..62fc991 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -130,14 +130,13 @@ class SupervisedSolver(SolverInterface): if not hasattr(condition, "output_points"): raise NotImplementedError( f"{type(self).__name__} works only in data-driven mode.") - output_pts = out[condition_idx == condition_id] input_pts = pts[condition_idx == condition_id] input_pts.labels = pts.labels output_pts.labels = out.labels - loss = (self.loss_data(input_pts=input_pts, output_pts=output_pts)) + loss = self.loss_data(input_pts=input_pts, output_pts=output_pts) loss = loss.as_subclass(torch.Tensor) self.log("mean_loss", float(loss), prog_bar=True, logger=True) diff --git a/pina/trainer.py b/pina/trainer.py index 884eef7..3de0d7e 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -60,9 +60,12 @@ class Trainer(pytorch_lightning.Trainer): if not self.solver.problem.collector.full: error_message = '\n'.join( [ - f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}' + f"""{" " * 13} ---> Condition {key} {"sampled" if value else + "not sampled"}""" for key, value in - self.solver.problem.collector._is_conditions_ready.items()]) + self._solver.problem.collector._is_conditions_ready.items() + ] + ) raise RuntimeError('Cannot create Trainer if not all conditions ' 'are sampled. The Trainer got the following:\n' f'{error_message}') diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 653b0d6..503ddd6 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,13 +1,15 @@ import math import torch -from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, UnsupervisedDataset, unsupervised_dataset +from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, \ + UnsupervisedDataset from pina.data import PinaDataLoader from pina import LabelTensor, Condition from pina.equation import Equation from pina.domain import CartesianDomain -from pina.problem import SpatialProblem +from pina.problem import SpatialProblem, AbstractProblem from pina.operators import laplacian from pina.equation.equation_factory import FixedValue +from pina.graph import Graph def laplace_equation(input_, output_): @@ -30,49 +32,49 @@ class Poisson(SpatialProblem): conditions = { 'gamma1': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': 1 - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': [0, 1], + 'y': 1 + }), + equation=FixedValue(0.0)), 'gamma2': - Condition(domain=CartesianDomain({ - 'x': [0, 1], - 'y': 0 - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': [0, 1], + 'y': 0 + }), + equation=FixedValue(0.0)), 'gamma3': - Condition(domain=CartesianDomain({ - 'x': 1, - 'y': [0, 1] - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': 1, + 'y': [0, 1] + }), + equation=FixedValue(0.0)), 'gamma4': - Condition(domain=CartesianDomain({ - 'x': 0, - 'y': [0, 1] - }), - equation=FixedValue(0.0)), + Condition(domain=CartesianDomain({ + 'x': 0, + 'y': [0, 1] + }), + equation=FixedValue(0.0)), 'D': - Condition(input_points=LabelTensor(torch.rand(size=(100, 2)), - ['x', 'y']), - equation=my_laplace), + Condition(input_points=LabelTensor(torch.rand(size=(100, 2)), + ['x', 'y']), + equation=my_laplace), 'data': - Condition(input_points=in_, output_points=out_), + Condition(input_points=in_, output_points=out_), 'data2': - Condition(input_points=in2_, output_points=out2_), + Condition(input_points=in2_, output_points=out2_), 'unsupervised': - Condition( - input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']), - conditional_variables=LabelTensor(torch.ones(size=(45, 1)), - ['alpha']), - ), + Condition( + input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']), + conditional_variables=LabelTensor(torch.ones(size=(45, 1)), + ['alpha']), + ), 'unsupervised2': - Condition( - input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']), - conditional_variables=LabelTensor(torch.ones(size=(90, 1)), - ['alpha']), - ) + Condition( + input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']), + conditional_variables=LabelTensor(torch.ones(size=(90, 1)), + ['alpha']), + ) } @@ -98,8 +100,8 @@ def test_data(): assert dataset.input_points.shape == (61, 2) assert dataset['input_points'].labels == ['x', 'y'] assert dataset.input_points.labels == ['x', 'y'] - assert dataset['input_points', 3:].shape == (58, 2) - assert dataset[3:][1].labels == ['u'] + assert dataset.input_points[3:].shape == (58, 2) + assert dataset.output_points[:3].labels == ['u'] assert dataset.output_points.shape == (61, 1) assert dataset.output_points.labels == ['u'] assert dataset.condition_indices.dtype == torch.uint8 @@ -193,4 +195,32 @@ def test_loader(): assert i.unsupervised.input_points.requires_grad == True -test_loader() +coordinates = LabelTensor(torch.rand((100, 100, 2)), labels=['x', 'y']) +data = LabelTensor(torch.rand((100, 100, 3)), labels=['ux', 'uy', 'p']) + + +class GraphProblem(AbstractProblem): + output = LabelTensor(torch.rand((100, 3)), labels=['ux', 'uy', 'p']) + input = [Graph.build('radius', + nodes_coordinates=coordinates[i, :, :], + nodes_data=data[i, :, :], radius=0.2) + for i in + range(100)] + output_variables = ['u'] + + conditions = { + 'graph_data': Condition(input_points=input, output_points=output) + } + + +graph_problem = GraphProblem() + + +def test_loader_graph(): + data_module = PinaDataModule(graph_problem, device='cpu', batch_size=10) + data_module.setup() + loader = data_module.train_dataloader() + for i in loader: + assert len(i) <= 10 + assert isinstance(i.supervised.input_points, list) + assert all(isinstance(x, Graph) for x in i.supervised.input_points)