Implement Dataset, DataLoader and DataModule classes and fix SupervisedSolver

FilippoOlivo
2024-10-16 11:24:37 +02:00
committed by Nicola Demo
parent b9753c34b2
commit c9304fb9bb
30 changed files with 770 additions and 784 deletions

View File

@@ -5,7 +5,8 @@ __all__ = [
     "Plotter",
     "Condition",
     "SamplePointDataset",
-    "SamplePointLoader",
+    "PinaDataModule",
+    "PinaDataLoader"
 ]
 from .meta import *
@@ -15,4 +16,5 @@ from .trainer import Trainer
 from .plotter import Plotter
 from .condition.condition import Condition
 from .data import SamplePointDataset
-from .data import SamplePointLoader
+from .data import PinaDataModule
+from .data import PinaDataLoader

View File

@@ -3,6 +3,7 @@ from sympy.strategies.branch import condition
 from . import LabelTensor
 from .utils import check_consistency, merge_tensors

+
 class Collector:
     def __init__(self, problem):
         # creating a hook between collector and problem

View File

@@ -5,7 +5,7 @@ class ConditionInterface(metaclass=ABCMeta):

     condition_types = ['physics', 'supervised', 'unsupervised']

-    def __init__(self, *args, **wargs):
+    def __init__(self, *args, **kwargs):
         self._condition_type = None
         self._problem = None

View File

@@ -22,11 +22,11 @@ class DataConditionInterface(ConditionInterface):
         super().__init__()
         self.input_points = input_points
         self.conditional_variables = conditional_variables
-        self.condition_type = 'unsupervised'
+        self._condition_type = 'unsupervised'

     def __setattr__(self, key, value):
         if (key == 'input_points') or (key == 'conditional_variables'):
             check_consistency(value, (LabelTensor, Graph, torch.Tensor))
             DataConditionInterface.__dict__[key].__set__(self, value)
-        elif key in ('problem', 'condition_type'):
+        elif key in ('_problem', '_condition_type'):
             super().__setattr__(key, value)

View File

@@ -20,7 +20,7 @@ class DomainEquationCondition(ConditionInterface):
         super().__init__()
         self.domain = domain
         self.equation = equation
-        self.condition_type = 'physics'
+        self._condition_type = 'physics'

     def __setattr__(self, key, value):
         if key == 'domain':
@@ -29,5 +29,5 @@ class DomainEquationCondition(ConditionInterface):
         elif key == 'equation':
             check_consistency(value, (EquationInterface))
             DomainEquationCondition.__dict__[key].__set__(self, value)
-        elif key in ('problem', 'condition_type'):
+        elif key in ('_problem', '_condition_type'):
             super().__setattr__(key, value)

View File

@@ -30,5 +30,5 @@ class InputPointsEquationCondition(ConditionInterface):
         elif key == 'equation':
             check_consistency(value, (EquationInterface))
             InputPointsEquationCondition.__dict__[key].__set__(self, value)
-        elif key in ('problem', 'condition_type'):
+        elif key in ('_problem', '_condition_type'):
             super().__setattr__(key, value)

View File

@@ -21,11 +21,11 @@ class InputOutputPointsCondition(ConditionInterface):
         super().__init__()
         self.input_points = input_points
         self.output_points = output_points
-        self.condition_type = ['supervised', 'physics']
+        self._condition_type = ['supervised', 'physics']

     def __setattr__(self, key, value):
         if (key == 'input_points') or (key == 'output_points'):
             check_consistency(value, (LabelTensor, Graph, torch.Tensor))
             InputOutputPointsCondition.__dict__[key].__set__(self, value)
-        elif key in ('problem', 'condition_type'):
+        elif key in ('_problem', '_condition_type'):
             super().__setattr__(key, value)

View File

@@ -1,7 +1,20 @@
+"""
+Import data classes
+"""
 __all__ = [
+    'PinaDataLoader',
+    'SupervisedDataset',
+    'SamplePointDataset',
+    'UnsupervisedDataset',
+    'Batch',
+    'PinaDataModule',
+    'BaseDataset'
 ]
-from .pina_dataloader import SamplePointLoader
-from .data_dataset import DataPointDataset
+from .pina_dataloader import PinaDataLoader
+from .supervised_dataset import SupervisedDataset
 from .sample_dataset import SamplePointDataset
+from .unsupervised_dataset import UnsupervisedDataset
 from .pina_batch import Batch
+from .data_module import PinaDataModule
+from .base_dataset import BaseDataset

107 pina/data/base_dataset.py Normal file
View File

@@ -0,0 +1,107 @@
"""
Basic data module implementation
"""
from torch.utils.data import Dataset
import torch

from ..label_tensor import LabelTensor


class BaseDataset(Dataset):
    """
    BaseDataset class, which handles initialization and data retrieval
    :var condition_indices: List of condition indices
    :var device: torch.device
    :var condition_names: dict of condition index and corresponding name
    """

    def __new__(cls, problem, device):
        """
        Ensure correct definition of __slots__ before initialization
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.device device: The device on which the
            dataset will be loaded.
        """
        if cls is BaseDataset:
            raise TypeError(
                'BaseDataset cannot be instantiated directly. Use a subclass.')
        if not hasattr(cls, '__slots__'):
            raise TypeError(
                'Something is wrong, __slots__ must be defined in subclasses.')
        return super().__new__(cls)

    def __init__(self, problem, device):
        """
        Initialize the object based on __slots__
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.device device: The device on which the
            dataset will be loaded.
        """
        super().__init__()
        self.condition_names = {}

        collector = problem.collector
        for slot in self.__slots__:
            setattr(self, slot, [])

        idx = 0
        for name, data in collector.data_collections.items():
            keys = []
            for k, v in data.items():
                if isinstance(v, LabelTensor):
                    keys.append(k)
            if sorted(self.__slots__) == sorted(keys):
                for slot in self.__slots__:
                    current_list = getattr(self, slot)
                    current_list.append(data[slot])
                self.condition_names[idx] = name
                idx += 1

        if len(getattr(self, self.__slots__[0])) > 0:
            input_list = getattr(self, self.__slots__[0])
            self.condition_indices = torch.cat(
                [
                    torch.tensor([i] * len(input_list[i]), dtype=torch.uint8)
                    for i in range(len(self.condition_names))
                ],
                dim=0,
            )
            for slot in self.__slots__:
                current_attribute = getattr(self, slot)
                setattr(self, slot, LabelTensor.vstack(current_attribute))
        else:
            self.condition_indices = torch.tensor([], dtype=torch.uint8)
            for slot in self.__slots__:
                setattr(self, slot, torch.tensor([]))

        self.device = device

    def __len__(self):
        return len(getattr(self, self.__slots__[0]))

    def __getattribute__(self, item):
        attribute = super().__getattribute__(item)
        if isinstance(attribute, LabelTensor) and attribute.dtype == torch.float32:
            attribute = attribute.to(device=self.device).requires_grad_()
        return attribute

    def __getitem__(self, idx):
        if isinstance(idx, str):
            return getattr(self, idx).to(self.device)

        if isinstance(idx, slice):
            to_return_list = []
            for i in self.__slots__:
                to_return_list.append(getattr(self, i)[[idx]].to(self.device))
            return to_return_list

        if isinstance(idx, (tuple, list)):
            if (len(idx) == 2 and isinstance(idx[0], str)
                    and isinstance(idx[1], (list, slice))):
                tensor = getattr(self, idx[0])
                return tensor[[idx[1]]].to(self.device)
            if all(isinstance(x, int) for x in idx):
                to_return_list = []
                for i in self.__slots__:
                    to_return_list.append(
                        getattr(self, i)[[idx]].to(self.device))
                return to_return_list

        raise ValueError(f'Invalid index {idx}')
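
As a quick orientation for the new dataset API (it is exercised by the updated tests at the end of this commit), a minimal sketch; the ToyProblem below is hypothetical and only mirrors how the tests build problems:

import torch
from pina import Condition, LabelTensor
from pina.problem import AbstractProblem
from pina.data import SupervisedDataset


class ToyProblem(AbstractProblem):
    # hypothetical problem with a single supervised condition
    input_variables = ['x']
    output_variables = ['u']
    conditions = {
        'data': Condition(
            input_points=LabelTensor(torch.rand(10, 1), ['x']),
            output_points=LabelTensor(torch.rand(10, 1), ['u'])),
    }


dataset = SupervisedDataset(ToyProblem(), device='cpu')
len(dataset)                 # 10 points collected from the 'data' condition
dataset.input_points         # stacked LabelTensor, shape (10, 1)
dataset['input_points', 3:]  # ('name', slice) indexing, shape (7, 1)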

View File

@@ -1,41 +0,0 @@
from torch.utils.data import Dataset
import torch
from ..label_tensor import LabelTensor


class DataPointDataset(Dataset):

    def __init__(self, problem, device) -> None:
        super().__init__()
        input_list = []
        output_list = []
        self.condition_names = []

        for name, condition in problem.conditions.items():
            if hasattr(condition, "output_points"):
                input_list.append(problem.conditions[name].input_points)
                output_list.append(problem.conditions[name].output_points)
                self.condition_names.append(name)

        self.input_pts = LabelTensor.stack(input_list)
        self.output_pts = LabelTensor.stack(output_list)

        if self.input_pts != []:
            self.condition_indeces = torch.cat(
                [
                    torch.tensor([i] * len(input_list[i]))
                    for i in range(len(self.condition_names))
                ],
                dim=0,
            )
        else:  # if there are no data points
            self.condition_indeces = torch.tensor([])
            self.input_pts = torch.tensor([])
            self.output_pts = torch.tensor([])

        self.input_pts = self.input_pts.to(device)
        self.output_pts = self.output_pts.to(device)
        self.condition_indeces = self.condition_indeces.to(device)

    def __len__(self):
        return self.input_pts.shape[0]

172 pina/data/data_module.py Normal file
View File

@@ -0,0 +1,172 @@
"""
This module provides basic data management functionalities
"""
import math
import torch
from lightning import LightningDataModule
from .sample_dataset import SamplePointDataset
from .supervised_dataset import SupervisedDataset
from .unsupervised_dataset import UnsupervisedDataset
from .pina_dataloader import PinaDataLoader
from .pina_subset import PinaSubset


class PinaDataModule(LightningDataModule):
    """
    This class extends LightningDataModule, allowing proper creation and
    management of different types of Datasets defined in PINA
    """

    def __init__(self,
                 problem,
                 device,
                 train_size=.7,
                 test_size=.2,
                 eval_size=.1,
                 batch_size=None,
                 shuffle=True,
                 datasets=None):
        """
        Initialize the object, creating datasets based on the input problem
        :param AbstractProblem problem: PINA problem
        :param device: Device used for training and testing
        :param train_size: number/percentage of elements in the train split
        :param test_size: number/percentage of elements in the test split
        :param eval_size: number/percentage of elements in the evaluation split
        :param batch_size: batch size used for training
        :param datasets: list of dataset objects
        """
        super().__init__()
        dataset_classes = [SupervisedDataset, UnsupervisedDataset,
                           SamplePointDataset]
        if datasets is None:
            self.datasets = [DatasetClass(problem, device)
                             for DatasetClass in dataset_classes]
        else:
            self.datasets = datasets

        self.split_length = []
        self.split_names = []
        if train_size > 0:
            self.split_names.append('train')
            self.split_length.append(train_size)
        if test_size > 0:
            self.split_length.append(test_size)
            self.split_names.append('test')
        if eval_size > 0:
            self.split_length.append(eval_size)
            self.split_names.append('eval')

        self.batch_size = batch_size
        self.condition_names = None
        self.splits = {k: {} for k in self.split_names}
        self.shuffle = shuffle

    def setup(self, stage=None):
        """
        Perform the splitting of the dataset
        """
        self.extract_conditions()
        if stage == 'fit' or stage is None:
            for dataset in self.datasets:
                if len(dataset) > 0:
                    splits = self.dataset_split(dataset,
                                                self.split_length,
                                                shuffle=self.shuffle)
                    for i in range(len(self.split_length)):
                        self.splits[
                            self.split_names[i]][dataset.data_type] = splits[i]
        elif stage == 'test':
            raise NotImplementedError("Testing pipeline not implemented yet")
        else:
            raise ValueError("stage must be either 'fit' or 'test'")

    def extract_conditions(self):
        """
        Extract conditions from the datasets and update condition indices
        """
        # Extract number of conditions
        n_conditions = 0
        for dataset in self.datasets:
            if n_conditions != 0:
                dataset.condition_names = {
                    key + n_conditions: value
                    for key, value in dataset.condition_names.items()
                }
            n_conditions += len(dataset.condition_names)

        self.condition_names = {
            key: value
            for dataset in self.datasets
            for key, value in dataset.condition_names.items()
        }

    def train_dataloader(self):
        """
        Return the training dataloader for the dataset
        :return: data loader
        :rtype: PinaDataLoader
        """
        return PinaDataLoader(self.splits['train'], self.batch_size,
                              self.condition_names)

    def test_dataloader(self):
        """
        Return the testing dataloader for the dataset
        :return: data loader
        :rtype: PinaDataLoader
        """
        return PinaDataLoader(self.splits['test'], self.batch_size,
                              self.condition_names)

    def eval_dataloader(self):
        """
        Return the evaluation dataloader for the dataset
        :return: data loader
        :rtype: PinaDataLoader
        """
        return PinaDataLoader(self.splits['eval'], self.batch_size,
                              self.condition_names)

    @staticmethod
    def dataset_split(dataset, lengths, seed=None, shuffle=True):
        """
        Perform the splitting of the dataset
        :param dataset: dataset object we want to split
        :param lengths: lengths (or fractions) of the splits
        :param seed: random seed
        :param shuffle: shuffle dataset
        :return: list of split datasets
        :rtype: list(PinaSubset)
        """
        if sum(lengths) - 1 < 1e-3:
            lengths = [
                int(math.floor(len(dataset) * length)) for length in lengths
            ]
            remainder = len(dataset) - sum(lengths)
            for i in range(remainder):
                lengths[i % len(lengths)] += 1
        elif sum(lengths) - 1 >= 1e-3:
            raise ValueError(f"Sum of lengths is {sum(lengths)}, "
                             "greater than 1")

        if sum(lengths) != len(dataset):
            raise ValueError("Sum of lengths is not equal to dataset length")

        if shuffle:
            if seed is not None:
                generator = torch.Generator()
                generator.manual_seed(seed)
                indices = torch.randperm(sum(lengths),
                                         generator=generator).tolist()
            else:
                indices = torch.arange(sum(lengths)).tolist()
        else:
            indices = torch.arange(0, sum(lengths), 1,
                                   dtype=torch.uint8).tolist()
        offsets = [
            sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
        ]
        return [
            PinaSubset(dataset, indices[offset:offset + length])
            for offset, length in zip(offsets, lengths)
        ]
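
A sketch of how the module is meant to be driven, reusing the hypothetical ToyProblem above; the split arithmetic of dataset_split is worked out in the comments:

from pina.data import PinaDataModule

data_module = PinaDataModule(ToyProblem(), device='cpu',
                             train_size=.7, test_size=.2, eval_size=.1,
                             batch_size=5, shuffle=False)
data_module.setup()                       # builds the train/test/eval splits
loader = data_module.train_dataloader()   # PinaDataLoader over 'train'
# dataset_split on 10 supervised points with lengths [.7, .2, .1]:
# flooring gives [7, 2, 1], which already sums to 10, so the remainder
# loop adds nothing and the splits hold 7 / 2 / 1 points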

View File

@@ -1,36 +1,33 @@
+"""
+Batch management module
+"""
+from .pina_subset import PinaSubset
+
+
 class Batch:
-    """
-    This class is used to create a dataset of sample points.
-    """

-    def __init__(self, type_, idx, *args, **kwargs) -> None:
-        """
-        """
-        if type_ == "sample":
-            if len(args) != 2:
-                raise RuntimeError
-            input = args[0]
-            conditions = args[1]
-            self.input = input[idx]
-            self.condition = conditions[idx]
-        elif type_ == "data":
-            if len(args) != 3:
-                raise RuntimeError
-            input = args[0]
-            output = args[1]
-            conditions = args[2]
-            self.input = input[idx]
-            self.output = output[idx]
-            self.condition = conditions[idx]
-        else:
-            raise ValueError("Invalid number of arguments.")
+    def __init__(self, dataset_dict, idx_dict):
+        for k, v in dataset_dict.items():
+            setattr(self, k, v)
+        for k, v in idx_dict.items():
+            setattr(self, k + '_idx', v)
+
+    def __len__(self):
+        """
+        Returns the number of elements in the batch
+        :return: number of elements in the batch
+        :rtype: int
+        """
+        length = 0
+        for dataset in dir(self):
+            attribute = getattr(self, dataset)
+            if isinstance(attribute, list):
+                length += len(getattr(self, dataset))
+        return length
+
+    def __getattr__(self, item):
+        if not item in dir(self):
+            raise AttributeError(f'Batch instance has no attribute {item}')
+        return PinaSubset(getattr(self, item).dataset,
+                          getattr(self, item).indices[self.coordinates_dict[item]])

View File

@@ -1,11 +1,11 @@
-import torch
-
-from .sample_dataset import SamplePointDataset
-from .data_dataset import DataPointDataset
+"""
+This module is used to create an iterable object used during training
+"""
+import math
 from .pina_batch import Batch


-class SamplePointLoader:
+class PinaDataLoader:
     """
     This class is used to create a dataloader to use during the training.
@@ -14,198 +14,54 @@ class SamplePointLoader:
     :vartype condition_names: list[str]
     """

-    def __init__(
-        self, sample_dataset, data_dataset, batch_size=None, shuffle=True
-    ) -> None:
+    def __init__(self, dataset_dict, batch_size, condition_names) -> None:
         """
-        Constructor.
-
-        :param SamplePointDataset sample_pts: The sample points dataset.
-        :param int batch_size: The batch size. If ``None``, the batch size is
-            set to the number of sample points. Default is ``None``.
-        :param bool shuffle: If ``True``, the sample points are shuffled.
-            Default is ``True``.
+        Initialize local variables
+        :param dataset_dict: Dictionary of datasets
+        :type dataset_dict: dict
+        :param batch_size: Size of the batch
+        :type batch_size: int
+        :param condition_names: Names of the conditions
+        :type condition_names: list[str]
         """
-        if not isinstance(sample_dataset, SamplePointDataset):
-            raise TypeError(
-                f"Expected SamplePointDataset, got {type(sample_dataset)}"
-            )
-        if not isinstance(data_dataset, DataPointDataset):
-            raise TypeError(
-                f"Expected DataPointDataset, got {type(data_dataset)}"
-            )
-
-        self.n_data_conditions = len(data_dataset.condition_names)
-        self.n_phys_conditions = len(sample_dataset.condition_names)
-        data_dataset.condition_indeces += self.n_phys_conditions
-
-        self._prepare_sample_dataset(sample_dataset, batch_size, shuffle)
-        self._prepare_data_dataset(data_dataset, batch_size, shuffle)
-
-        self.condition_names = (
-            sample_dataset.condition_names + data_dataset.condition_names
-        )
-
-        self.batch_list = []
-        for i in range(len(self.batch_sample_pts)):
-            self.batch_list.append(("sample", i))
-
-        for i in range(len(self.batch_input_pts)):
-            self.batch_list.append(("data", i))
-
-        if shuffle:
-            self.random_idx = torch.randperm(len(self.batch_list))
-        else:
-            self.random_idx = torch.arange(len(self.batch_list))
-
-        self._prepare_batches()
+        self.condition_names = condition_names
+        self.dataset_dict = dataset_dict
+        self._init_batches(batch_size)

-    def _prepare_data_dataset(self, dataset, batch_size, shuffle):
+    def _init_batches(self, batch_size=None):
         """
-        Prepare the dataset for data points.
-
-        :param SamplePointDataset dataset: The dataset.
-        :param int batch_size: The batch size.
-        :param bool shuffle: If ``True``, the sample points are shuffled.
+        Create batches according to the batch_size provided in input.
         """
-        self.sample_dataset = dataset
-
-        if len(dataset) == 0:
-            self.batch_data_conditions = []
-            self.batch_input_pts = []
-            self.batch_output_pts = []
-            return
-
-        if batch_size is None:
-            batch_size = len(dataset)
-        batch_num = len(dataset) // batch_size
-        if len(dataset) % batch_size != 0:
-            batch_num += 1
-
-        output_labels = dataset.output_pts.labels
-        input_labels = dataset.input_pts.labels
-        self.tensor_conditions = dataset.condition_indeces
-
-        if shuffle:
-            idx = torch.randperm(dataset.input_pts.shape[0])
-            self.input_pts = dataset.input_pts[idx]
-            self.output_pts = dataset.output_pts[idx]
-            self.tensor_conditions = dataset.condition_indeces[idx]
-
-        self.batch_input_pts = torch.tensor_split(dataset.input_pts, batch_num)
-        self.batch_output_pts = torch.tensor_split(
-            dataset.output_pts, batch_num
-        )
-        #print(input_labels)
-        for i in range(len(self.batch_input_pts)):
-            self.batch_input_pts[i].labels = input_labels
-            self.batch_output_pts[i].labels = output_labels
-
-        self.batch_data_conditions = torch.tensor_split(
-            self.tensor_conditions, batch_num
-        )
-
-    def _prepare_sample_dataset(self, dataset, batch_size, shuffle):
-        """
-        Prepare the dataset for sample points.
-
-        :param DataPointDataset dataset: The dataset.
-        :param int batch_size: The batch size.
-        :param bool shuffle: If ``True``, the sample points are shuffled.
-        """
-        self.sample_dataset = dataset
-
-        if len(dataset) == 0:
-            self.batch_sample_conditions = []
-            self.batch_sample_pts = []
-            return
-
-        if batch_size is None:
-            batch_size = len(dataset)
-        batch_num = len(dataset) // batch_size
-        if len(dataset) % batch_size != 0:
-            batch_num += 1
-
-        self.tensor_pts = dataset.pts
-        self.tensor_conditions = dataset.condition_indeces
-
-        # if shuffle:
-        #     idx = torch.randperm(self.tensor_pts.shape[0])
-        #     self.tensor_pts = self.tensor_pts[idx]
-        #     self.tensor_conditions = self.tensor_conditions[idx]
-
-        self.batch_sample_pts = torch.tensor_split(self.tensor_pts, batch_num)
-        for i in range(len(self.batch_sample_pts)):
-            self.batch_sample_pts[i].labels = dataset.pts.labels
-
-        self.batch_sample_conditions = torch.tensor_split(
-            self.tensor_conditions, batch_num
-        )
-
-    def _prepare_batches(self):
-        """
-        Prepare the batches.
-        """
-        self.batches = []
-        for i in range(len(self.batch_list)):
-            type_, idx_ = self.batch_list[i]
-            if type_ == "sample":
-                batch = Batch(
-                    "sample", idx_,
-                    self.batch_sample_pts,
-                    self.batch_sample_conditions)
-            else:
-                batch = Batch(
-                    "data", idx_,
-                    self.batch_input_pts,
-                    self.batch_output_pts,
-                    self.batch_data_conditions)
-            self.batches.append(batch)
+        self.batches = []
+        n_elements = sum([len(v) for v in self.dataset_dict.values()])
+        if batch_size is None:
+            batch_size = n_elements
+        indexes_dict = {}
+        n_batches = int(math.ceil(n_elements / batch_size))
+        for k, v in self.dataset_dict.items():
+            if n_batches != 1:
+                indexes_dict[k] = math.floor(len(v) / (n_batches - 1))
+            else:
+                indexes_dict[k] = len(v)
+        for i in range(n_batches):
+            temp_dict = {}
+            for k, v in indexes_dict.items():
+                if i != n_batches - 1:
+                    temp_dict[k] = slice(i * v, (i + 1) * v)
+                else:
+                    temp_dict[k] = slice(i * v, len(self.dataset_dict[k]))
+            self.batches.append(
+                Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))

     def __iter__(self):
         """
-        Return an iterator over the points. Any element of the iterator is a
-        dictionary with the following keys:
-
-            - ``pts``: The input sample points. It is a LabelTensor with the
-              shape ``(batch_size, input_dimension)``.
-            - ``output``: The output sample points. This key is present only
-              if data conditions are present. It is a LabelTensor with the
-              shape ``(batch_size, output_dimension)``.
-            - ``condition``: The integer condition indeces. It is a tensor
-              with the shape ``(batch_size, )`` of type ``torch.int64`` and
-              indicates for any ``pts`` the corresponding problem condition.
-
-        :return: An iterator over the points.
-        :rtype: iter
+        Makes dataloader object iterable
         """
-        # for i in self.random_idx:
-        for i in self.random_idx:
-            yield self.batches[i]
-        # for i in range(len(self.batch_list)):
-        #     type_, idx_ = self.batch_list[i]
-        #     if type_ == "sample":
-        #         d = {
-        #             "pts": self.batch_sample_pts[idx_].requires_grad_(True),
-        #             "condition": self.batch_sample_conditions[idx_],
-        #         }
-        #     else:
-        #         d = {
-        #             "pts": self.batch_input_pts[idx_].requires_grad_(True),
-        #             "output": self.batch_output_pts[idx_],
-        #             "condition": self.batch_data_conditions[idx_],
-        #         }
-        #     yield d
+        yield from self.batches

     def __len__(self):
         """
         Return the number of batches.
         :return: The number of batches.
         :rtype: int
         """
-        return len(self.batch_list)
+        return len(self.batches)
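
Iteration then follows the pattern used in the updated tests; a sketch, with `loader` as built in the PinaDataModule sketch above:

for batch in loader:
    pts = batch.supervised.input_points    # LabelTensor on the target device
    out = batch.supervised.output_points
    # batch.physics and batch.unsupervised appear the same way whenever
    # the corresponding datasets are non-empty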

21 pina/data/pina_subset.py Normal file
View File

@@ -0,0 +1,21 @@
class PinaSubset:
    """
    TODO
    """
    __slots__ = ['dataset', 'indices']

    def __init__(self, dataset, indices):
        """
        TODO
        """
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        """
        TODO
        """
        return len(self.indices)

    def __getattr__(self, name):
        return self.dataset.__getattribute__(name)
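
PinaSubset is a thin view: its length comes from the indices, and any unknown attribute falls through to the wrapped dataset. A sketch, reusing the hypothetical `dataset` above:

from pina.data.pina_subset import PinaSubset

subset = PinaSubset(dataset, indices=[0, 2, 4])
len(subset)              # 3, the length of the indices list
subset.condition_names   # delegated to the dataset via __getattr__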

View File

@@ -1,43 +1,12 @@
-from torch.utils.data import Dataset
-import torch
-from ..label_tensor import LabelTensor
-
-
-class SamplePointDataset(Dataset):
-    """
-    This class is used to create a dataset of sample points.
-    """
-
-    def __init__(self, problem, device) -> None:
-        """
-        :param dict input_pts: The input points.
-        """
-        super().__init__()
-        pts_list = []
-        self.condition_names = []
-
-        for name, condition in problem.conditions.items():
-            if not hasattr(condition, "output_points"):
-                pts_list.append(problem.input_pts[name])
-                self.condition_names.append(name)
-
-        self.pts = LabelTensor.stack(pts_list)
-
-        if self.pts != []:
-            self.condition_indeces = torch.cat(
-                [
-                    torch.tensor([i] * len(pts_list[i]))
-                    for i in range(len(self.condition_names))
-                ],
-                dim=0,
-            )
-        else:  # if there are no sample points
-            self.condition_indeces = torch.tensor([])
-            self.pts = torch.tensor([])
-
-        self.pts = self.pts.to(device)
-        self.condition_indeces = self.condition_indeces.to(device)
-
-    def __len__(self):
-        return self.pts.shape[0]
+"""
+Sample dataset module
+"""
+from .base_dataset import BaseDataset
+
+
+class SamplePointDataset(BaseDataset):
+    """
+    This class extends the BaseDataset to handle physical datasets
+    composed of only input points.
+    """
+    data_type = 'physics'
+    __slots__ = ['input_points']

View File

@@ -0,0 +1,12 @@
"""
Supervised dataset module
"""
from .base_dataset import BaseDataset


class SupervisedDataset(BaseDataset):
    """
    This class extends the BaseDataset to handle datasets
    that consist of input-output pairs.
    """
    data_type = 'supervised'
    __slots__ = ['input_points', 'output_points']

View File

@@ -0,0 +1,13 @@
"""
Unsupervised dataset module
"""
from .base_dataset import BaseDataset


class UnsupervisedDataset(BaseDataset):
    """
    This class extends the BaseDataset to handle unsupervised datasets,
    composed of input points and, optionally, conditional variables
    """
    data_type = 'unsupervised'
    __slots__ = ['input_points', 'conditional_variables']

View File

@@ -55,7 +55,6 @@ class EllipsoidDomain(DomainInterface):
         # perform operation only for not fixed variables (if any)
         if self.range_:
-
             # convert dict vals to torch [dim, 2] matrix
             list_dict_vals = list(self.range_.values())
             tmp = torch.tensor(list_dict_vals, dtype=torch.float)

View File

@@ -3,6 +3,7 @@ from copy import deepcopy, copy
 import torch
 from torch import Tensor

+
 def issubset(a, b):
     """
     Check if a is a subset of b.
@@ -382,6 +383,7 @@ class LabelTensor(torch.Tensor):
     def sort_labels(self, dim=None):

         def argsort(lst):
             return sorted(range(len(lst)), key=lambda x: lst[x])
+
         if dim is None:
             dim = self.tensor.ndim - 1
         labels = self.full_labels[dim]['dof']

View File

@@ -10,168 +10,6 @@ import torch
 import sys

-# class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
-#     """
-#     Solver base class. This class inherits is a wrapper of
-#     LightningModule class, inheriting all the
-#     LightningModule methods.
-#     """
-
-#     def __init__(
-#         self,
-#         models,
-#         problem,
-#         optimizers,
-#         optimizers_kwargs,
-#         extra_features=None,
-#     ):
-#         """
-#         :param models: A torch neural network model instance.
-#         :type models: torch.nn.Module
-#         :param problem: A problem definition instance.
-#         :type problem: AbstractProblem
-#         :param list(torch.optim.Optimizer) optimizer: A list of neural network optimizers to
-#             use.
-#         :param list(dict) optimizer_kwargs: A list of optimizer constructor keyword args.
-#         :param list(torch.nn.Module) extra_features: The additional input
-#             features to use as augmented input. If ``None`` no extra features
-#             are passed. If it is a list of :class:`torch.nn.Module`, the extra feature
-#             list is passed to all models. If it is a list of extra features' lists,
-#             each single list of extra feature is passed to a model.
-#         """
-#         super().__init__()
-
-#         # check consistency of the inputs
-#         check_consistency(models, torch.nn.Module)
-#         check_consistency(problem, AbstractProblem)
-#         check_consistency(optimizers, torch.optim.Optimizer, subclass=True)
-#         check_consistency(optimizers_kwargs, dict)
-
-#         # put everything in a list if only one input
-#         if not isinstance(models, list):
-#             models = [models]
-#         if not isinstance(optimizers, list):
-#             optimizers = [optimizers]
-#             optimizers_kwargs = [optimizers_kwargs]
-
-#         # number of models and optimizers
-#         len_model = len(models)
-#         len_optimizer = len(optimizers)
-#         len_optimizer_kwargs = len(optimizers_kwargs)
-
-#         # check length consistency optimizers
-#         if len_model != len_optimizer:
-#             raise ValueError(
-#                 "You must define one optimizer for each model."
-#                 f"Got {len_model} models, and {len_optimizer}"
-#                 " optimizers."
-#             )
-
-#         # check length consistency optimizers kwargs
-#         if len_optimizer_kwargs != len_optimizer:
-#             raise ValueError(
-#                 "You must define one dictionary of keyword"
-#                 " arguments for each optimizers."
-#                 f"Got {len_optimizer} optimizers, and"
-#                 f" {len_optimizer_kwargs} dicitionaries"
-#             )
-
-#         # extra features handling
-#         if (extra_features is None) or (len(extra_features) == 0):
-#             extra_features = [None] * len_model
-#         else:
-#             # if we only have a list of extra features
-#             if not isinstance(extra_features[0], (tuple, list)):
-#                 extra_features = [extra_features] * len_model
-#             else:  # if we have a list of list extra features
-#                 if len(extra_features) != len_model:
-#                     raise ValueError(
-#                         "You passed a list of extrafeatures list with len"
-#                         f"different of models len. Expected {len_model} "
-#                         f"got {len(extra_features)}. If you want to use "
-#                         "the same list of extra features for all models, "
-#                         "just pass a list of extrafeatures and not a list "
-#                         "of list of extra features."
-#                     )
-
-#         # assigning model and optimizers
-#         self._pina_models = []
-#         self._pina_optimizers = []
-
-#         for idx in range(len_model):
-#             model_ = Network(
-#                 model=models[idx],
-#                 input_variables=problem.input_variables,
-#                 output_variables=problem.output_variables,
-#                 extra_features=extra_features[idx],
-#             )
-#             optim_ = optimizers[idx](
-#                 model_.parameters(), **optimizers_kwargs[idx]
-#             )
-#             self._pina_models.append(model_)
-#             self._pina_optimizers.append(optim_)
-
-#         # assigning problem
-#         self._pina_problem = problem
-
-#     @abstractmethod
-#     def forward(self, *args, **kwargs):
-#         pass
-
-#     @abstractmethod
-#     def training_step(self):
-#         pass
-
-#     @abstractmethod
-#     def configure_optimizers(self):
-#         pass
-
-#     @property
-#     def models(self):
-#         """
-#         The torch model."""
-#         return self._pina_models
-
-#     @property
-#     def optimizers(self):
-#         """
-#         The torch model."""
-#         return self._pina_optimizers
-
-#     @property
-#     def problem(self):
-#         """
-#         The problem formulation."""
-#         return self._pina_problem
-
-#     def on_train_start(self):
-#         """
-#         On training epoch start this function is call to do global checks for
-#         the different solvers.
-#         """
-
-#         # 1. Check the verison for dataloader
-#         dataloader = self.trainer.train_dataloader
-#         if sys.version_info < (3, 8):
-#             dataloader = dataloader.loaders
-#         self._dataloader = dataloader
-
-#         return super().on_train_start()
-
-#     @model.setter
-#     def model(self, new_model):
-#         """
-#         Set the torch."""
-#         check_consistency(new_model, nn.Module, 'torch model')
-#         self._model= new_model
-
-#     @problem.setter
-#     def problem(self, problem):
-#         """
-#         Set the problem formulation."""
-#         check_consistency(problem, AbstractProblem, 'pina problem')
-#         self._problem = problem

 class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
     """
     Solver base class. This class inherits is a wrapper of
@@ -181,10 +19,12 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):

     def __init__(
         self,
-        model,
+        models,
         problem,
-        optimizer,
-        scheduler,
+        optimizers,
+        schedulers,
+        extra_features,
+        use_lt=True
     ):
         """
         :param model: A torch neural network model instance.
@@ -197,22 +37,45 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
         super().__init__()

         # check consistency of the inputs
-        check_consistency(model, torch.nn.Module)
         check_consistency(problem, AbstractProblem)
-        check_consistency(optimizer, Optimizer)
-        check_consistency(scheduler, Scheduler)
+        self._check_solver_consistency(problem)
+
+        # Check consistency of models argument and encapsulate in list
+        if not isinstance(models, list):
+            check_consistency(models, torch.nn.Module)
+            # put everything in a list if only one input
+            models = [models]
+        else:
+            for idx in range(len(models)):
+                # Check consistency
+                check_consistency(models[idx], torch.nn.Module)
+        len_model = len(models)

-        # put everything in a list if only one input
-        if not isinstance(model, list):
-            model = [model]
-        if not isinstance(scheduler, list):
-            scheduler = [scheduler]
-        if not isinstance(optimizer, list):
-            optimizer = [optimizer]
+        # If use_lt is true add extract operation in input
+        if use_lt is True:
+            for idx in range(len(models)):
+                models[idx] = Network(
+                    model=models[idx],
+                    input_variables=problem.input_variables,
+                    output_variables=problem.output_variables,
+                    extra_features=extra_features,
+                )

-        # number of models and optimizers
-        len_model = len(model)
-        len_optimizer = len(optimizer)
+        # Check scheduler consistency + encapsulation
+        if not isinstance(schedulers, list):
+            check_consistency(schedulers, Scheduler)
+            schedulers = [schedulers]
+        else:
+            for scheduler in schedulers:
+                check_consistency(scheduler, Scheduler)
+
+        # Check optimizer consistency + encapsulation
+        if not isinstance(optimizers, list):
+            check_consistency(optimizers, Optimizer)
+            optimizers = [optimizers]
+        else:
+            for optimizer in optimizers:
+                check_consistency(optimizer, Optimizer)
+        len_optimizer = len(optimizers)

         # check length consistency optimizers
         if len_model != len_optimizer:
@@ -223,10 +86,12 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
             )

         # extra features handling
+        self._pina_models = models
+        self._pina_optimizers = optimizers
+        self._pina_schedulers = schedulers
         self._pina_problem = problem
-        self._pina_model = model
-        self._pina_optimizer = optimizer
-        self._pina_scheduler = scheduler

     @abstractmethod
     def forward(self, *args, **kwargs):
@@ -244,13 +109,13 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
     def models(self):
         """
         The torch model."""
-        return self._pina_model
+        return self._pina_models

     @property
     def optimizers(self):
         """
         The torch model."""
-        return self._pina_optimizer
+        return self._pina_optimizers

     @property
     def problem(self):
@@ -272,16 +137,10 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
         return super().on_train_start()

-    # @model.setter
-    # def model(self, new_model):
-    #     """
-    #     Set the torch."""
-    #     check_consistency(new_model, nn.Module, 'torch model')
-    #     self._model= new_model
-
-    # @problem.setter
-    # def problem(self, problem):
-    #     """
-    #     Set the problem formulation."""
-    #     check_consistency(problem, AbstractProblem, 'pina problem')
-    #     self._problem = problem
+    def _check_solver_consistency(self, problem):
+        """
+        TODO
+        """
+        for _, condition in problem.conditions.items():
+            if not set(self.accepted_condition_types).issubset(
+                    condition.condition_type):
+                raise ValueError(
+                    f'{self.__name__} does not support condition '
+                    f'{condition.condition_type}')
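
Concrete solvers opt in to the new check by declaring the condition types they accept; a minimal hypothetical sketch, mirroring the SupervisedSolver changes below:

from pina.solvers.solver import SolverInterface

class MySolver(SolverInterface):
    accepted_condition_types = ['supervised']
    __name__ = 'MySolver'
    # _check_solver_consistency raises ValueError unless every condition's
    # condition_type contains all accepted types, which is why the purely
    # physical Poisson problem is rejected by SupervisedSolver in the tests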

View File

@@ -2,9 +2,7 @@
 import torch
 from torch.nn.modules.loss import _Loss
-from ..optim import TorchOptimizer, TorchScheduler
+from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
 from .solver import SolverInterface
 from ..label_tensor import LabelTensor
 from ..utils import check_consistency
@@ -39,6 +37,8 @@ class SupervisedSolver(SolverInterface):
     we are seeking to approximate multiple (discretised) functions given
     multiple (discretised) input functions.
     """
+    accepted_condition_types = ['supervised']
+    __name__ = 'SupervisedSolver'

     def __init__(
         self,
@@ -47,6 +47,7 @@ class SupervisedSolver(SolverInterface):
         loss=None,
         optimizer=None,
         scheduler=None,
+        extra_features=None
     ):
         """
         :param AbstractProblem problem: The formulation of the problem.
@@ -57,11 +58,8 @@ class SupervisedSolver(SolverInterface):
             features to use as augmented input.
         :param torch.optim.Optimizer optimizer: The neural network optimizer to
             use; default is :class:`torch.optim.Adam`.
-        :param dict optimizer_kwargs: Optimizer constructor keyword args.
-        :param float lr: The learning rate; default is 0.001.
         :param torch.optim.LRScheduler scheduler: Learning
             rate scheduler.
-        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
         """
         if loss is None:
             loss = torch.nn.MSELoss()
@@ -74,18 +72,19 @@ class SupervisedSolver(SolverInterface):
                 torch.optim.lr_scheduler.ConstantLR)

         super().__init__(
-            model=model,
+            models=model,
             problem=problem,
-            optimizer=optimizer,
-            scheduler=scheduler,
+            optimizers=optimizer,
+            schedulers=scheduler,
+            extra_features=extra_features
         )

         # check consistency
         check_consistency(loss, (LossInterface, _Loss), subclass=False)
         self._loss = loss

-        self._model = self._pina_model[0]
-        self._optimizer = self._pina_optimizer[0]
-        self._scheduler = self._pina_scheduler[0]
+        self._model = self._pina_models[0]
+        self._optimizer = self._pina_optimizers[0]
+        self._scheduler = self._pina_schedulers[0]

     def forward(self, x):
         """Forward pass implementation for the solver.
@@ -97,12 +96,7 @@ class SupervisedSolver(SolverInterface):
         output = self._model(x)

-        output.labels = {
-            1: {
-                "name": "output",
-                "dof": self.problem.output_variables
-            }
-        }
+        output.labels = self.problem.output_variables

         return output

     def configure_optimizers(self):
@@ -128,16 +122,14 @@ class SupervisedSolver(SolverInterface):
         :return: The sum of the loss functions.
         :rtype: LabelTensor
         """
-        condition_idx = batch.condition
+        condition_idx = batch.supervised.condition_indices

         for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

             condition_name = self._dataloader.condition_names[condition_id]
             condition = self.problem.conditions[condition_name]
-            pts = batch.input
-            out = batch.output
+            pts = batch.supervised.input_points
+            out = batch.supervised.output_points

             if condition_name not in self.problem.conditions:
                 raise RuntimeError("Something wrong happened.")
@@ -167,8 +159,8 @@ class SupervisedSolver(SolverInterface):
         the network output against the true solution. This function
         should not be overridden unless intentionally.

-        :param LabelTensor input_tensor: The input to the neural networks.
-        :param LabelTensor output_tensor: The true solution to compare the
+        :param LabelTensor input_pts: The input to the neural networks.
+        :param LabelTensor output_pts: The true solution to compare the
             network solution.
         :return: The residual loss averaged on the input coordinates
         :rtype: torch.Tensor
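
Putting the pieces together, the updated constructor surface; a sketch using `problem` and `model` as defined in the updated tests below:

solver = SupervisedSolver(problem=problem, model=model,
                          extra_features=None)  # new keyword in this commit
# forward() now labels its output with problem.output_variables directly,
# and training_step reads batch.supervised.input_points / output_points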

View File

@@ -3,13 +3,13 @@
 import torch
 import pytorch_lightning
 from .utils import check_consistency
-from .data import SamplePointDataset, SamplePointLoader, DataPointDataset
+from .data import PinaDataModule
 from .solvers.solver import SolverInterface


 class Trainer(pytorch_lightning.Trainer):

-    def __init__(self, solver, batch_size=None, **kwargs):
+    def __init__(self, solver, batch_size=None, train_size=.7,
+                 test_size=.2, eval_size=.1, **kwargs):
         """
         PINA Trainer class for customizing every aspect of training via flags.
@@ -31,10 +31,11 @@ class Trainer(pytorch_lightning.Trainer):
         check_consistency(solver, SolverInterface)
         if batch_size is not None:
             check_consistency(batch_size, int)
-
+        self.train_size = train_size
+        self.test_size = test_size
+        self.eval_size = eval_size
         self.solver = solver
         self.batch_size = batch_size
         self._create_loader()
         self._move_to_device()
@@ -69,11 +70,12 @@ class Trainer(pytorch_lightning.Trainer):
             raise RuntimeError("Parallel training is not supported yet.")

         device = devices[0]
-        dataset_phys = SamplePointDataset(self.solver.problem, device)
-        dataset_data = DataPointDataset(self.solver.problem, device)
-        self._loader = SamplePointLoader(
-            dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True
-        )
+        data_module = PinaDataModule(problem=self.solver.problem, device=device,
+                                     train_size=self.train_size,
+                                     test_size=self.test_size,
+                                     eval_size=self.eval_size)
+        data_module.setup()
+        self._loader = data_module.train_dataloader()

     def train(self, **kwargs):
         """
@@ -89,3 +91,7 @@ class Trainer(pytorch_lightning.Trainer):
         Returning trainer solver.
         """
         return self._solver
+
+    @solver.setter
+    def solver(self, solver):
+        self._solver = solver

View File

@@ -1,12 +1,11 @@
+import math
 import torch
-import pytest
-from pina.data.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
+from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, UnsupervisedDataset
+from pina.data import PinaDataLoader
 from pina import LabelTensor, Condition
 from pina.equation import Equation
 from pina.domain import CartesianDomain
 from pina.problem import SpatialProblem
-from pina.model import FeedForward
 from pina.operators import laplacian
 from pina.equation.equation_factory import FixedValue
@@ -17,28 +16,30 @@ def laplace_equation(input_, output_):
     delta_u = laplacian(output_.extract(['u']), input_)
     return delta_u - force_term


 my_laplace = Equation(laplace_equation)
 in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
 out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
 in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
 out2_ = LabelTensor(torch.rand(60, 1), ['u'])


 class Poisson(SpatialProblem):
     output_variables = ['u']
     spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})

     conditions = {
         'gamma1': Condition(
-            location=CartesianDomain({'x': [0, 1], 'y': 1}),
+            domain=CartesianDomain({'x': [0, 1], 'y': 1}),
             equation=FixedValue(0.0)),
         'gamma2': Condition(
-            location=CartesianDomain({'x': [0, 1], 'y': 0}),
+            domain=CartesianDomain({'x': [0, 1], 'y': 0}),
             equation=FixedValue(0.0)),
         'gamma3': Condition(
-            location=CartesianDomain({'x': 1, 'y': [0, 1]}),
+            domain=CartesianDomain({'x': 1, 'y': [0, 1]}),
             equation=FixedValue(0.0)),
         'gamma4': Condition(
-            location=CartesianDomain({'x': 0, 'y': [0, 1]}),
+            domain=CartesianDomain({'x': 0, 'y': [0, 1]}),
             equation=FixedValue(0.0)),
         'D': Condition(
             input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
@@ -48,75 +49,114 @@ class Poisson(SpatialProblem):
             output_points=out_),
         'data2': Condition(
             input_points=in2_,
-            output_points=out2_)
+            output_points=out2_),
+        'unsupervised': Condition(
+            input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
+            conditional_variables=LabelTensor(torch.ones(size=(45, 1)), ['alpha']),
+        ),
+        'unsupervised2': Condition(
+            input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
+            conditional_variables=LabelTensor(torch.ones(size=(90, 1)), ['alpha']),
+        )
     }
     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']


 poisson = Poisson()
 poisson.discretise_domain(10, 'grid', locations=boundaries)


 def test_sample():
     sample_dataset = SamplePointDataset(poisson, device='cpu')
     assert len(sample_dataset) == 140
-    assert sample_dataset.pts.shape == (140, 2)
-    assert sample_dataset.pts.labels == ['x', 'y']
-    assert sample_dataset.condition_indeces.dtype == torch.int64
-    assert sample_dataset.condition_indeces.max() == torch.tensor(4)
-    assert sample_dataset.condition_indeces.min() == torch.tensor(0)
+    assert sample_dataset.input_points.shape == (140, 2)
+    assert sample_dataset.input_points.labels == ['x', 'y']
+    assert sample_dataset.condition_indices.dtype == torch.uint8
+    assert sample_dataset.condition_indices.max() == torch.tensor(4)
+    assert sample_dataset.condition_indices.min() == torch.tensor(0)


 def test_data():
-    dataset = DataPointDataset(poisson, device='cpu')
+    dataset = SupervisedDataset(poisson, device='cpu')
     assert len(dataset) == 61
-    assert dataset.input_pts.shape == (61, 2)
-    assert dataset.input_pts.labels == ['x', 'y']
-    assert dataset.output_pts.shape == (61, 1)
-    assert dataset.output_pts.labels == ['u']
-    assert dataset.condition_indeces.dtype == torch.int64
-    assert dataset.condition_indeces.max() == torch.tensor(1)
-    assert dataset.condition_indeces.min() == torch.tensor(0)
+    assert dataset['input_points'].shape == (61, 2)
+    assert dataset.input_points.shape == (61, 2)
+    assert dataset['input_points'].labels == ['x', 'y']
+    assert dataset.input_points.labels == ['x', 'y']
+    assert dataset['input_points', 3:].shape == (58, 2)
+    assert dataset[3:][1].labels == ['u']
+    assert dataset.output_points.shape == (61, 1)
+    assert dataset.output_points.labels == ['u']
+    assert dataset.condition_indices.dtype == torch.uint8
+    assert dataset.condition_indices.max() == torch.tensor(1)
+    assert dataset.condition_indices.min() == torch.tensor(0)


+def test_unsupervised():
+    dataset = UnsupervisedDataset(poisson, device='cpu')
+    assert len(dataset) == 135
+    assert dataset.input_points.shape == (135, 2)
+    assert dataset.input_points.labels == ['x', 'y']
+    assert dataset.input_points[3:].shape == (132, 2)
+    assert dataset.conditional_variables.shape == (135, 1)
+    assert dataset.conditional_variables.labels == ['alpha']
+    assert dataset.condition_indices.dtype == torch.uint8
+    assert dataset.condition_indices.max() == torch.tensor(1)
+    assert dataset.condition_indices.min() == torch.tensor(0)
+
+
+def test_data_module():
+    data_module = PinaDataModule(poisson, device='cpu')
+    data_module.setup()
+    loader = data_module.train_dataloader()
+    assert isinstance(loader, PinaDataLoader)
+
+    data_module = PinaDataModule(poisson, device='cpu', batch_size=10,
+                                 shuffle=False)
+    data_module.setup()
+    loader = data_module.train_dataloader()
+    assert len(loader) == 24
+    for i in loader:
+        assert len(i) <= 10
+    len_ref = sum([math.ceil(len(dataset) * 0.7)
+                   for dataset in data_module.datasets])
+    len_real = sum([len(dataset)
+                    for dataset in data_module.splits['train'].values()])
+    assert len_ref == len_real
+
+    supervised_dataset = SupervisedDataset(poisson, device='cpu')
+    data_module = PinaDataModule(poisson, device='cpu', batch_size=10,
+                                 shuffle=False, datasets=[supervised_dataset])
+    data_module.setup()
+    loader = data_module.train_dataloader()
+    for batch in loader:
+        assert len(batch) <= 10
+
+    physics_dataset = SamplePointDataset(poisson, device='cpu')
+    data_module = PinaDataModule(poisson, device='cpu', batch_size=10,
+                                 shuffle=False, datasets=[physics_dataset])
+    data_module.setup()
+    loader = data_module.train_dataloader()
+    for batch in loader:
+        assert len(batch) <= 10
+
+    unsupervised_dataset = UnsupervisedDataset(poisson, device='cpu')
+    data_module = PinaDataModule(poisson, device='cpu', batch_size=10,
+                                 shuffle=False, datasets=[unsupervised_dataset])
+    data_module.setup()
+    loader = data_module.train_dataloader()
+    for batch in loader:
+        assert len(batch) <= 10


 def test_loader():
-    sample_dataset = SamplePointDataset(poisson, device='cpu')
-    data_dataset = DataPointDataset(poisson, device='cpu')
-    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
-
-    for batch in loader:
-        assert len(batch) in [2, 3]
-        assert batch['pts'].shape[0] <= 10
-        assert batch['pts'].requires_grad == True
-        assert batch['pts'].labels == ['x', 'y']
-
-    loader2 = SamplePointLoader(sample_dataset, data_dataset, batch_size=None)
-    assert len(list(loader2)) == 2
+    data_module = PinaDataModule(poisson, device='cpu', batch_size=10)
+    data_module.setup()
+    loader = data_module.train_dataloader()
+    assert isinstance(loader, PinaDataLoader)
+    assert len(loader) == 24
+    for i in loader:
+        assert len(i) <= 10
+        assert i.supervised.input_points.labels == ['x', 'y']
+        assert i.physics.input_points.labels == ['x', 'y']
+        assert i.unsupervised.input_points.labels == ['x', 'y']
+        assert i.supervised.input_points.requires_grad == True
+        assert i.physics.input_points.requires_grad == True
+        assert i.unsupervised.input_points.requires_grad == True


-def test_loader2():
-    poisson2 = Poisson()
-    del poisson.conditions['data2']
-    del poisson2.conditions['data']
-    poisson2.discretise_domain(10, 'grid', locations=boundaries)
-
-    sample_dataset = SamplePointDataset(poisson, device='cpu')
-    data_dataset = DataPointDataset(poisson, device='cpu')
-    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
-    for batch in loader:
-        assert len(batch) == 2  # only phys condtions
-        assert batch['pts'].shape[0] <= 10
-        assert batch['pts'].requires_grad == True
-        assert batch['pts'].labels == ['x', 'y']
-
-
-def test_loader3():
-    poisson2 = Poisson()
-    del poisson.conditions['gamma1']
-    del poisson.conditions['gamma2']
-    del poisson.conditions['gamma3']
-    del poisson.conditions['gamma4']
-    del poisson.conditions['D']
-
-    sample_dataset = SamplePointDataset(poisson, device='cpu')
-    data_dataset = DataPointDataset(poisson, device='cpu')
-    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
-    for batch in loader:
-        assert len(batch) == 2  # only phys condtions
-        assert batch['pts'].shape[0] <= 10
-        assert batch['pts'].requires_grad == True
-        assert batch['pts'].labels == ['x', 'y']
+test_loader()

View File

@@ -1,50 +1,27 @@
 import torch
+import pytest
-from pina.problem import AbstractProblem
+from pina.problem import AbstractProblem, SpatialProblem
 from pina import Condition, LabelTensor
 from pina.solvers import SupervisedSolver
-from pina.trainer import Trainer
 from pina.model import FeedForward
-from pina.loss import LpLoss
-from pina.solvers import GraphSupervisedSolver
+from pina.equation.equation import Equation
+from pina.equation.equation_factory import FixedValue
+from pina.operators import laplacian
+from pina.domain import CartesianDomain
+from pina.trainer import Trainer

+in_ = LabelTensor(torch.tensor([[0., 1.]]), ['u_0', 'u_1'])
+out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
+
+
 class NeuralOperatorProblem(AbstractProblem):
     input_variables = ['u_0', 'u_1']
     output_variables = ['u']
-    domains = {
-        'pts': LabelTensor(
-            torch.rand(100, 2),
-            labels={1: {'name': 'space', 'dof': ['u_0', 'u_1']}}
-        )
-    }
     conditions = {
-        'data': Condition(
-            domain='pts',
-            output_points=LabelTensor(
-                torch.rand(100, 1),
-                labels={1: {'name': 'output', 'dof': ['u']}}
-            )
-        )
+        'data': Condition(input_points=in_, output_points=out_),
     }
-
-
-class NeuralOperatorProblemGraph(AbstractProblem):
-    input_variables = ['x', 'y', 'u_0', 'u_1']
-    output_variables = ['u']
-    domains = {
-        'pts': LabelTensor(
-            torch.rand(100, 4),
-            labels={1: {'name': 'space', 'dof': ['x', 'y', 'u_0', 'u_1']}}
-        )
-    }
-    conditions = {
-        'data': Condition(
-            domain='pts',
-            output_points=LabelTensor(
-                torch.rand(100, 1),
-                labels={1: {'name': 'output', 'dof': ['u']}}
-            )
-        )
-    }


 class myFeature(torch.nn.Module):
     """
@@ -61,117 +38,106 @@ class myFeature(torch.nn.Module):

 problem = NeuralOperatorProblem()
-problem_graph = NeuralOperatorProblemGraph()
-# make the problem + extra feats
 extra_feats = [myFeature()]
-model = FeedForward(len(problem.input_variables),
-                    len(problem.output_variables))
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 model_extra_feats = FeedForward(
-    len(problem.input_variables) + 1,
-    len(problem.output_variables))
+    len(problem.input_variables) + 1, len(problem.output_variables))


 def test_constructor():
     SupervisedSolver(problem=problem, model=model)

-# def test_constructor_extra_feats():
-#     SupervisedSolver(problem=problem, model=model_extra_feats, extra_features=extra_feats)
-
-'''
-class AutoSolver(SupervisedSolver):
-
-    def forward(self, input):
-        from pina.graph import Graph
-        print(Graph)
-        print(input)
-        if not isinstance(input, Graph):
-            input = Graph.build('radius', nodes_coordinates=input, nodes_data=torch.rand(input.shape), radius=0.2)
-            print(input)
-            print(input.data.edge_index)
-        print(input.data)
-        g = self._model(input.data, edge_index=input.data.edge_index)
-        g.labels = {1: {'name': 'output', 'dof': ['u']}}
-        return g
-
-        du_dt_new = LabelTensor(self.model(graph).reshape(-1,1), labels = ['du'])
-        return du_dt_new
-'''
-
-class GraphModel(torch.nn.Module):
-    def __init__(self, in_channels, out_channels):
-        from torch_geometric.nn import GCNConv, NNConv
-        super().__init__()
-        self.conv1 = GCNConv(in_channels, 16)
-        self.conv2 = GCNConv(16, out_channels)
-
-    def forward(self, data, edge_index):
-        print(data)
-        x = data.x
-        print(x)
-        x = self.conv1(x, edge_index)
-        x = x.relu()
-        x = self.conv2(x, edge_index)
-        return x
-
-
-def test_graph():
-    solver = GraphSupervisedSolver(problem=problem_graph, model=GraphModel(2, 1), loss=LpLoss(),
-                                   nodes_coordinates=['x', 'y'], nodes_data=['u_0', 'u_1'])
-    trainer = Trainer(solver=solver, max_epochs=30, accelerator='cpu', batch_size=20)
-    trainer.train()
+test_constructor()
+
+
+def laplace_equation(input_, output_):
+    force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
+                  torch.sin(input_.extract(['y']) * torch.pi))
+    delta_u = laplacian(output_.extract(['u']), input_)
+    return delta_u - force_term
+
+
+my_laplace = Equation(laplace_equation)
+
+
+class Poisson(SpatialProblem):
+    output_variables = ['u']
+    spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
+
+    conditions = {
+        'gamma1':
+            Condition(domain=CartesianDomain({
+                'x': [0, 1],
+                'y': 1
+            }),
+                      equation=FixedValue(0.0)),
+        'gamma2':
+            Condition(domain=CartesianDomain({
+                'x': [0, 1],
+                'y': 0
+            }),
+                      equation=FixedValue(0.0)),
+        'gamma3':
+            Condition(domain=CartesianDomain({
+                'x': 1,
+                'y': [0, 1]
+            }),
+                      equation=FixedValue(0.0)),
+        'gamma4':
+            Condition(domain=CartesianDomain({
+                'x': 0,
+                'y': [0, 1]
+            }),
+                      equation=FixedValue(0.0)),
+        'D':
+            Condition(domain=CartesianDomain({
+                'x': [0, 1],
+                'y': [0, 1]
+            }),
+                      equation=my_laplace),
+        'data':
+            Condition(input_points=in_, output_points=out_)
+    }
+
+    def poisson_sol(self, pts):
+        return -(torch.sin(pts.extract(['x']) * torch.pi) *
+                 torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi**2)
+
+    truth_solution = poisson_sol
+
+
+def test_wrong_constructor():
+    poisson_problem = Poisson()
+    with pytest.raises(ValueError):
+        SupervisedSolver(problem=poisson_problem, model=model)


 def test_train_cpu():
-    solver = SupervisedSolver(problem=problem, model=model, loss=LpLoss())
-    trainer = Trainer(solver=solver, max_epochs=300, accelerator='cpu', batch_size=20)
+    solver = SupervisedSolver(problem=problem, model=model)
+    trainer = Trainer(solver=solver,
+                      max_epochs=200,
+                      accelerator='cpu',
+                      batch_size=5,
+                      train_size=1,
+                      test_size=0.,
+                      eval_size=0.)
     trainer.train()


-# def test_train_restore():
-#     tmpdir = "tests/tmp_restore"
-#     solver = SupervisedSolver(problem=problem,
-#                               model=model,
-#                               extra_features=None,
-#                               loss=LpLoss())
-#     trainer = Trainer(solver=solver,
-#                       max_epochs=5,
-#                       accelerator='cpu',
-#                       default_root_dir=tmpdir)
-#     trainer.train()
-#     ntrainer = Trainer(solver=solver, max_epochs=15, accelerator='cpu')
-#     t = ntrainer.train(
-#         ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
-#     import shutil
-#     shutil.rmtree(tmpdir)
+test_train_cpu()

-# def test_train_load():
-#     tmpdir = "tests/tmp_load"
-#     solver = SupervisedSolver(problem=problem,
-#                               model=model,
-#                               extra_features=None,
-#                               loss=LpLoss())
-#     trainer = Trainer(solver=solver,
-#                       max_epochs=15,
-#                       accelerator='cpu',
-#                       default_root_dir=tmpdir)
-#     trainer.train()
-#     new_solver = SupervisedSolver.load_from_checkpoint(
-#         f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
-#         problem = problem, model=model)
-#     test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables)
-#     assert new_solver.forward(test_pts).shape == (20, 1)
-#     assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
-#     torch.testing.assert_close(
-#         new_solver.forward(test_pts),
-#         solver.forward(test_pts))
-#     import shutil
-#     shutil.rmtree(tmpdir)
+def test_extra_features_constructor():
+    SupervisedSolver(problem=problem,
+                     model=model_extra_feats,
+                     extra_features=extra_feats)

-# def test_train_extra_feats_cpu():
-#     pinn = SupervisedSolver(problem=problem,
-#                             model=model_extra_feats,
-#                             extra_features=extra_feats)
-#     trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
-#     trainer.train()
-
-test_graph()
+def test_extra_features_train_cpu():
+    solver = SupervisedSolver(problem=problem,
+                              model=model_extra_feats,
+                              extra_features=extra_feats)
+    trainer = Trainer(solver=solver,
+                      max_epochs=200,
+                      accelerator='cpu',
+                      batch_size=5)
+    trainer.train()