Add Graph support in Dataset and Dataloader

This commit is contained in:
FilippoOlivo
2024-10-23 15:04:28 +02:00
committed by Nicola Demo
parent eb146ea2ea
commit ccc5f5a322
11 changed files with 125 additions and 75 deletions

View File

@@ -49,7 +49,7 @@ class Collector:
# if the condition is not ready and domain is not attribute
# of condition, we get and store the data
if (not self._is_conditions_ready[condition_name]) and (
not hasattr(condition, "domain")):
not hasattr(condition, "domain")):
# get data
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
@@ -94,7 +94,8 @@ class Collector:
self.data_collections[loc] = dict(zip(keys, values))
else:
raise RuntimeError(
'Try to sample variables which are not in problem defined in the problem')
'Try to sample variables which are not in problem defined '
'in the problem')
def add_points(self, new_points_dict):
"""

View File

@@ -4,6 +4,7 @@ Basic data module implementation
from torch.utils.data import Dataset
import torch
from ..label_tensor import LabelTensor
from ..graph import Graph
class BaseDataset(Dataset):
@@ -42,38 +43,43 @@ class BaseDataset(Dataset):
collector = problem.collector
for slot in self.__slots__:
setattr(self, slot, [])
num_el_per_condition = []
idx = 0
for name, data in collector.data_collections.items():
keys = []
for k, v in data.items():
if isinstance(v, LabelTensor):
keys.append(k)
keys = list(data.keys())
current_cond_num_el = None
if sorted(self.__slots__) == sorted(keys):
for slot in self.__slots__:
slot_data = data[slot]
if isinstance(slot_data, (LabelTensor, torch.Tensor,
Graph)):
if current_cond_num_el is None:
current_cond_num_el = len(slot_data)
elif current_cond_num_el != len(slot_data):
raise ValueError('Different number of conditions')
current_list = getattr(self, slot)
current_list.append(data[slot])
current_list += [data[slot]] if not (
isinstance(data[slot], list)) else data[slot]
num_el_per_condition.append(current_cond_num_el)
self.condition_names[idx] = name
idx += 1
if len(getattr(self, self.__slots__[0])) > 0:
input_list = getattr(self, self.__slots__[0])
if num_el_per_condition:
self.condition_indices = torch.cat(
[
torch.tensor([i] * len(input_list[i]), dtype=torch.uint8)
for i in range(len(self.condition_names))
torch.tensor([i] * num_el_per_condition[i],
dtype=torch.uint8)
for i in range(len(num_el_per_condition))
],
dim=0,
)
for slot in self.__slots__:
current_attribute = getattr(self, slot)
setattr(self, slot, LabelTensor.vstack(current_attribute))
if all(isinstance(a, LabelTensor) for a in current_attribute):
setattr(self, slot, LabelTensor.vstack(current_attribute))
else:
self.condition_indices = torch.tensor([], dtype=torch.uint8)
for slot in self.__slots__:
setattr(self, slot, torch.tensor([]))
self.device = device
def __len__(self):
@@ -89,11 +95,10 @@ class BaseDataset(Dataset):
def __getitem__(self, idx):
if isinstance(idx, str):
return getattr(self, idx).to(self.device)
if isinstance(idx, slice):
to_return_list = []
for i in self.__slots__:
to_return_list.append(getattr(self, i)[[idx]].to(self.device))
to_return_list.append(getattr(self, i)[idx].to(self.device))
return to_return_list
if isinstance(idx, (tuple, list)):

View File

@@ -6,7 +6,8 @@ from .pina_subset import PinaSubset
class Batch:
"""
Implementation of the Batch class used during training to perform SGD optimization.
Implementation of the Batch class used during training to perform SGD
optimization.
"""
def __init__(self, dataset_dict, idx_dict):

View File

@@ -1,6 +1,8 @@
"""
Module for PinaSubset class
"""
from pina import LabelTensor
from torch import Tensor
class PinaSubset:
@@ -23,4 +25,9 @@ class PinaSubset:
return len(self.indices)
def __getattr__(self, name):
return self.dataset.__getattribute__(name)
tensor = self.dataset.__getattribute__(name)
if isinstance(tensor, (LabelTensor, Tensor)):
return tensor[self.indices]
if isinstance(tensor, list):
return [tensor[i] for i in self.indices]
raise AttributeError("No attribute named {}".format(name))

View File

@@ -2,6 +2,8 @@
Sample dataset module
"""
from .base_dataset import BaseDataset
from ..condition.input_equation_condition import InputPointsEquationCondition
class SamplePointDataset(BaseDataset):
"""
@@ -9,4 +11,4 @@ class SamplePointDataset(BaseDataset):
composed of only input points.
"""
data_type = 'physics'
__slots__ = ['input_points']
__slots__ = InputPointsEquationCondition.__slots__

View File

@@ -6,7 +6,8 @@ from .base_dataset import BaseDataset
class SupervisedDataset(BaseDataset):
"""
This class extends the BaseDataset to handle datasets that consist of input-output pairs.
This class extends the BaseDataset to handle datasets that consist of
input-output pairs.
"""
data_type = 'supervised'
__slots__ = ['input_points', 'output_points']

View File

@@ -413,7 +413,6 @@ class LabelTensor(torch.Tensor):
return selected_lt
def _getitem_permutation(self, index, selected_lt):
new_labels = deepcopy(self.full_labels)
new_labels.update(self._update_label_for_dim(self.full_labels, index,
0))
@@ -429,6 +428,8 @@ class LabelTensor(torch.Tensor):
:param dim:
:return:
"""
if isinstance(index, torch.Tensor):
index = index.nonzero()
if isinstance(index, list):
return {dim: {'dof': [old_labels[dim]['dof'][i] for i in index],
'name': old_labels[dim]['name']}}
@@ -436,7 +437,6 @@ class LabelTensor(torch.Tensor):
return {dim: {'dof': old_labels[dim]['dof'][index],
'name': old_labels[dim]['name']}}
def sort_labels(self, dim=None):
def argsort(lst):
return sorted(range(len(lst)), key=lambda x: lst[x])

View File

@@ -38,7 +38,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
check_consistency(problem, AbstractProblem)
self._check_solver_consistency(problem)
#Check consistency of models argument and encapsulate in list
# Check consistency of models argument and encapsulate in list
if not isinstance(models, list):
check_consistency(models, torch.nn.Module)
# put everything in a list if only one input
@@ -49,17 +49,17 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
check_consistency(models[idx], torch.nn.Module)
len_model = len(models)
#If use_lt is true add extract operation in input
# If use_lt is true add extract operation in input
if use_lt is True:
for idx in range(len(models)):
for idx, model in enumerate(models):
models[idx] = Network(
model=models[idx],
model=model,
input_variables=problem.input_variables,
output_variables=problem.output_variables,
extra_features=extra_features,
)
#Check scheduler consistency + encapsulation
# Check scheduler consistency + encapsulation
if not isinstance(schedulers, list):
check_consistency(schedulers, Scheduler)
schedulers = [schedulers]
@@ -67,7 +67,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
for scheduler in schedulers:
check_consistency(scheduler, Scheduler)
#Check optimizer consistency + encapsulation
# Check optimizer consistency + encapsulation
if not isinstance(optimizers, list):
check_consistency(optimizers, Optimizer)
optimizers = [optimizers]
@@ -141,5 +141,6 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
if not set(self.accepted_condition_types).issubset(
condition.condition_type):
raise ValueError(
f'{self.__name__} support only dose not support condition {condition.condition_type}'
f'{self.__name__} support only dose not support condition '
f'{condition.condition_type}'
)

View File

@@ -130,14 +130,13 @@ class SupervisedSolver(SolverInterface):
if not hasattr(condition, "output_points"):
raise NotImplementedError(
f"{type(self).__name__} works only in data-driven mode.")
output_pts = out[condition_idx == condition_id]
input_pts = pts[condition_idx == condition_id]
input_pts.labels = pts.labels
output_pts.labels = out.labels
loss = (self.loss_data(input_pts=input_pts, output_pts=output_pts))
loss = self.loss_data(input_pts=input_pts, output_pts=output_pts)
loss = loss.as_subclass(torch.Tensor)
self.log("mean_loss", float(loss), prog_bar=True, logger=True)

View File

@@ -60,9 +60,12 @@ class Trainer(pytorch_lightning.Trainer):
if not self.solver.problem.collector.full:
error_message = '\n'.join(
[
f'{" " * 13} ---> Condition {key} {"sampled" if value else "not sampled"}'
f"""{" " * 13} ---> Condition {key} {"sampled" if value else
"not sampled"}"""
for key, value in
self.solver.problem.collector._is_conditions_ready.items()])
self._solver.problem.collector._is_conditions_ready.items()
]
)
raise RuntimeError('Cannot create Trainer if not all conditions '
'are sampled. The Trainer got the following:\n'
f'{error_message}')