Implementation of DataLoader and DataModule (#383)

Refactoring for 0.2
* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers

Co-authored-by: dario-coscia <dariocos99@gmail.com>
committed by Nicola Demo
parent dd43c8304c
commit a27bd35443
@@ -1,6 +1,6 @@
__all__ = [
    "Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
    "PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
    "Trainer", "LabelTensor", "Plotter", "Condition",
    "PinaDataModule", 'TorchOptimizer', 'Graph',
]

from .meta import *
@@ -9,9 +9,9 @@ from .solvers.solver import SolverInterface
from .trainer import Trainer
from .plotter import Plotter
from .condition.condition import Condition
from .data import SamplePointDataset

from .data import PinaDataModule
from .data import PinaDataLoader

from .optim import TorchOptimizer
from .optim import TorchScheduler
from .graph import Graph

@@ -1,3 +1,4 @@
from . import LabelTensor
from .utils import check_consistency, merge_tensors


@@ -66,9 +67,12 @@ class Collector:
        for loc in sample_locations:
            # get condition
            condition = self.problem.conditions[loc]
            condition_domain = condition.domain
            if isinstance(condition_domain, str):
                condition_domain = self.problem.domains[condition_domain]
            keys = ["input_points", "equation"]
            # if the condition is not ready, we get and store the data
            if (not self._is_conditions_ready[loc]):
            if not self._is_conditions_ready[loc]:
                # if it is the first time we sample
                if not self.data_collections[loc]:
                    already_sampled = []
@@ -84,10 +88,11 @@ class Collector:

            # get the samples
            samples = [
                condition.domain.sample(n=n, mode=mode, variables=variables)
            ] + already_sampled
                condition_domain.sample(n=n, mode=mode,
                                        variables=variables)
            ] + already_sampled
            pts = merge_tensors(samples)
            if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
            if set(pts.labels).issubset(sorted(self.problem.input_variables)):
                pts = pts.sort_labels()
                if sorted(pts.labels) == sorted(self.problem.input_variables):
                    self._is_conditions_ready[loc] = True
@@ -110,5 +115,6 @@ class Collector:
            if not self._is_conditions_ready[k]:
                raise RuntimeError(
                    'Cannot add points on a non sampled condition')
            self.data_collections[k]['input_points'] = self.data_collections[k][
                'input_points'].vstack(v)
            self.data_collections[k]['input_points'] = LabelTensor.vstack(
                [self.data_collections[k][
                    'input_points'], v])

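Note: the hunk above swaps the instance call `lt.vstack(v)` for the static `LabelTensor.vstack([lt, v])`. A minimal sketch of the same row-wise concatenation with plain `torch` tensors (the tensor names are hypothetical; `LabelTensor.vstack` additionally carries the column labels through the concatenation):

```python
import torch

stored = torch.rand(100, 2)  # points already collected for a condition
new = torch.rand(20, 2)      # points passed to add_points

# torch.vstack concatenates along dim 0, which is what the
# LabelTensor.vstack([...]) call above does for labelled tensors.
updated = torch.vstack([stored, new])
assert updated.shape == (120, 2)
```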
@@ -18,12 +18,11 @@ class DataConditionInterface(ConditionInterface):

    def __init__(self, input_points, conditional_variables=None):
        """
        TODO
        TODO : add docstring
        """
        super().__init__()
        self.input_points = input_points
        self.conditional_variables = conditional_variables
        self._condition_type = 'unsupervised'

    def __setattr__(self, key, value):
        if (key == 'input_points') or (key == 'conditional_variables'):

@@ -16,16 +16,15 @@ class DomainEquationCondition(ConditionInterface):

    def __init__(self, domain, equation):
        """
        TODO
        TODO : add docstring
        """
        super().__init__()
        self.domain = domain
        self.equation = equation
        self._condition_type = 'physics'

    def __setattr__(self, key, value):
        if key == 'domain':
            check_consistency(value, (DomainInterface))
            check_consistency(value, (DomainInterface, str))
            DomainEquationCondition.__dict__[key].__set__(self, value)
        elif key == 'equation':
            check_consistency(value, (EquationInterface))

@@ -17,12 +17,11 @@ class InputPointsEquationCondition(ConditionInterface):

    def __init__(self, input_points, equation):
        """
        TODO
        TODO : add docstring
        """
        super().__init__()
        self.input_points = input_points
        self.equation = equation
        self._condition_type = 'physics'

    def __setattr__(self, key, value):
        if key == 'input_points':

@@ -1,4 +1,5 @@
import torch
import torch_geometric

from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor
@@ -16,16 +17,15 @@ class InputOutputPointsCondition(ConditionInterface):

    def __init__(self, input_points, output_points):
        """
        TODO
        TODO : add docstring
        """
        super().__init__()
        self.input_points = input_points
        self.output_points = output_points
        self._condition_type = ['supervised', 'physics']

    def __setattr__(self, key, value):
        if (key == 'input_points') or (key == 'output_points'):
            check_consistency(value, (LabelTensor, Graph, torch.Tensor))
            check_consistency(value, (LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data))
            InputOutputPointsCondition.__dict__[key].__set__(self, value)
        elif key in ('_problem', '_condition_type'):
            super().__setattr__(key, value)

@@ -2,14 +2,11 @@
Import data classes
"""
__all__ = [
    'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset',
    'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset'
    'PinaDataModule',
    'PinaDataset'
]

from .pina_dataloader import PinaDataLoader
from .supervised_dataset import SupervisedDataset
from .sample_dataset import SamplePointDataset
from .unsupervised_dataset import UnsupervisedDataset
from .pina_batch import Batch


from .data_module import PinaDataModule
from .base_dataset import BaseDataset
from .dataset import PinaDataset

@@ -1,157 +0,0 @@
"""
Basic data module implementation
"""
import torch
import logging

from torch.utils.data import Dataset

from ..label_tensor import LabelTensor


class BaseDataset(Dataset):
    """
    BaseDataset class, which handles initialization and data retrieval
    :var condition_indices: List of indices
    :var device: torch.device
    """

    def __new__(cls, problem=None, device=torch.device('cpu')):
        """
        Ensure correct definition of __slots__ before initialization
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.device device: The device on which the
            dataset will be loaded.
        """
        if cls is BaseDataset:
            raise TypeError(
                'BaseDataset cannot be instantiated directly. Use a subclass.')
        if not hasattr(cls, '__slots__'):
            raise TypeError(
                'Something is wrong, __slots__ must be defined in subclasses.')
        return object.__new__(cls)

    def __init__(self, problem=None, device=torch.device('cpu')):
        """
        Initialize the object based on __slots__
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.device device: The device on which the
            dataset will be loaded.
        """
        super().__init__()
        self.empty = True
        self.problem = problem
        self.device = device
        self.condition_indices = None
        for slot in self.__slots__:
            setattr(self, slot, [])
        self.num_el_per_condition = []
        self.conditions_idx = []
        if self.problem is not None:
            self._init_from_problem(self.problem.collector.data_collections)
        self.initialized = False

    def _init_from_problem(self, collector_dict):
        """
        TODO
        """
        for name, data in collector_dict.items():
            keys = list(data.keys())
            if set(self.__slots__) == set(keys):
                self._populate_init_list(data)
                idx = [
                    key for key, val in
                    self.problem.collector.conditions_name.items()
                    if val == name
                ]
                self.conditions_idx.append(idx)
        self.initialize()

    def add_points(self, data_dict, condition_idx, batching_dim=0):
        """
        This method fills the internal lists of data points
        :param data_dict: dictionary containing data points
        :param condition_idx: index of the condition to which the data points
            belong
        :param batching_dim: dimension of the batching
        :raises: ValueError if the dataset has already been initialized
        """
        if not self.initialized:
            self._populate_init_list(data_dict, batching_dim)
            self.conditions_idx.append(condition_idx)
            self.empty = False
        else:
            raise ValueError('Dataset already initialized')

    def _populate_init_list(self, data_dict, batching_dim=0):
        current_cond_num_el = None
        for slot in data_dict.keys():
            slot_data = data_dict[slot]
            if batching_dim != 0:
                if isinstance(slot_data, (LabelTensor, torch.Tensor)):
                    dims = len(slot_data.size())
                    slot_data = slot_data.permute(
                        [batching_dim] +
                        [dim for dim in range(dims) if dim != batching_dim])
            if current_cond_num_el is None:
                current_cond_num_el = len(slot_data)
            elif current_cond_num_el != len(slot_data):
                raise ValueError('Different dimension in same condition')
            current_list = getattr(self, slot)
            current_list += [
                slot_data
            ] if not (isinstance(slot_data, list)) else slot_data
        self.num_el_per_condition.append(current_cond_num_el)

    def initialize(self):
        """
        Initialize the dataset's tensors/LabelTensors/lists given the lists
        already filled
        """
        logging.debug(f'Initialize dataset {self.__class__.__name__}')

        if self.num_el_per_condition:
            self.condition_indices = torch.cat([
                torch.tensor([i] * self.num_el_per_condition[i],
                             dtype=torch.uint8)
                for i in range(len(self.num_el_per_condition))
            ],
                dim=0)
        for slot in self.__slots__:
            current_attribute = getattr(self, slot)
            if all(isinstance(a, LabelTensor) for a in current_attribute):
                setattr(self, slot, LabelTensor.vstack(current_attribute))
        self.initialized = True

    def __len__(self):
        """
        :return: Number of elements in the dataset
        """
        return len(getattr(self, self.__slots__[0]))

    def __getitem__(self, idx):
        """
        :param idx:
        :return:
        """
        if not isinstance(idx, (tuple, list, slice, int)):
            raise IndexError("Invalid index")
        tensors = []
        for attribute in self.__slots__:
            tensor = getattr(self, attribute)
            if isinstance(tensor, (LabelTensor, torch.Tensor)):
                tensors.append(tensor.__getitem__(idx))
            elif isinstance(tensor, list):
                if isinstance(idx, (list, tuple)):
                    tensor = [tensor[i] for i in idx]
                tensors.append(tensor)
        return tensors

    def apply_shuffle(self, indices):
        for slot in self.__slots__:
            if slot != 'equation':
                attribute = getattr(self, slot)
                if isinstance(attribute, (LabelTensor, torch.Tensor)):
                    setattr(self, slot, attribute[[indices]])
                if isinstance(attribute, list):
                    setattr(self, slot, [attribute[i] for i in indices])
@@ -1,17 +1,71 @@
"""
This module provides basic data management functionalities
"""

import logging
from lightning.pytorch import LightningDataModule
import math
import torch
import logging
from pytorch_lightning import LightningDataModule
from .sample_dataset import SamplePointDataset
from .supervised_dataset import SupervisedDataset
from .unsupervised_dataset import UnsupervisedDataset
from .pina_dataloader import PinaDataLoader
from .pina_subset import PinaSubset
from ..label_tensor import LabelTensor
from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
    RandomSampler
from torch.utils.data.distributed import DistributedSampler
from .dataset import PinaDatasetFactory


class Collator:
    def __init__(self, max_conditions_lengths):
        self.max_conditions_lengths = max_conditions_lengths
        self.callable_function = self._collate_custom_dataloader if \
            max_conditions_lengths is None else (
                self._collate_standard_dataloader)

    @staticmethod
    def _collate_custom_dataloader(batch):
        return batch[0]

    def _collate_standard_dataloader(self, batch):
        """
        Function used to collate the batch
        """
        batch_dict = {}
        if isinstance(batch, dict):
            return batch
        conditions_names = batch[0].keys()

        # Condition names
        for condition_name in conditions_names:
            single_cond_dict = {}
            condition_args = batch[0][condition_name].keys()
            for arg in condition_args:
                data_list = [batch[idx][condition_name][arg] for idx in range(
                    min(len(batch),
                        self.max_conditions_lengths[condition_name]))]
                if isinstance(data_list[0], LabelTensor):
                    single_cond_dict[arg] = LabelTensor.stack(data_list)
                elif isinstance(data_list[0], torch.Tensor):
                    single_cond_dict[arg] = torch.stack(data_list)
                else:
                    raise NotImplementedError(
                        f"Data type {type(data_list[0])} not supported")
            batch_dict[condition_name] = single_cond_dict
        return batch_dict

    def __call__(self, batch):
        return self.callable_function(batch)
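Note: a minimal sketch of what the standard collate path does, using plain `torch` tensors (the condition and field names here are invented for the example):

```python
import torch

# Each element yielded by the dataset is a dict:
# {condition_name: {field_name: tensor}}.
batch = [
    {"pde": {"input_points": torch.rand(2)}},
    {"pde": {"input_points": torch.rand(2)}},
    {"pde": {"input_points": torch.rand(2)}},
]

# The standard collator stacks each field across the batch,
# capped at max_conditions_lengths for that condition.
max_len = 2  # pretend max_conditions_lengths["pde"] == 2
data_list = [batch[i]["pde"]["input_points"]
             for i in range(min(len(batch), max_len))]
collated = torch.stack(data_list)
assert collated.shape == (2, 2)
```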


class PinaBatchSampler(BatchSampler):
    def __init__(self, dataset, batch_size, shuffle, sampler=None):
        if sampler is None:
            if (torch.distributed.is_available() and
                    torch.distributed.is_initialized()):
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                sampler = DistributedSampler(dataset, shuffle=shuffle,
                                             rank=rank, num_replicas=world_size)
            else:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
        super().__init__(sampler=sampler, batch_size=batch_size,
                         drop_last=False)
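For intuition, a small runnable sketch of the single-process fallback path (no distributed setup), showing that a `BatchSampler` wrapped around a `RandomSampler` yields lists of indices rather than single indices:

```python
import torch
from torch.utils.data import BatchSampler, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
# Equivalent to the non-distributed, shuffle=True branch above.
sampler = BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False)
for batch_indices in sampler:
    print(batch_indices)  # e.g. [7, 2, 9, 0] -- a list of dataset indices
```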

class PinaDataModule(LightningDataModule):
    """
@@ -20,160 +74,218 @@ class PinaDataModule(LightningDataModule):
    """

    def __init__(self,
                 problem,
                 device,
                 collector,
                 train_size=.7,
                 test_size=.1,
                 val_size=.2,
                 test_size=.2,
                 val_size=.1,
                 predict_size=0.,
                 batch_size=None,
                 shuffle=True,
                 datasets=None):
                 repeat=False,
                 automatic_batching=False
                 ):
        """
        Initialize the object, creating the datasets based on the input problem
        :param AbstractProblem problem: PINA problem
        :param device: Device used for training and testing
        :param Collector collector: PINA Collector object
        :param train_size: number/percentage of elements in train split
        :param test_size: number/percentage of elements in test split
        :param eval_size: number/percentage of elements in evaluation split
        :param val_size: number/percentage of elements in validation split
        :param batch_size: batch size used for training
        :param datasets: list of dataset objects
        """
        logging.debug('Start initialization of Pina DataModule')
        logging.info('Start initialization of Pina DataModule')
        super().__init__()
        self.problem = problem
        self.device = device
        self.dataset_classes = [
            SupervisedDataset, UnsupervisedDataset, SamplePointDataset
        ]
        if datasets is None:
            self.datasets = None
        else:
            self.datasets = datasets

        self.split_length = []
        self.split_names = []
        self.loader_functions = {}
        self.default_batching = automatic_batching
        self.batch_size = batch_size
        self.condition_names = problem.collector.conditions_name

        if train_size > 0:
            self.split_names.append('train')
            self.split_length.append(train_size)
            self.loader_functions['train_dataloader'] = lambda: PinaDataLoader(
                self.splits['train'], self.batch_size, self.condition_names)
        if test_size > 0:
            self.split_length.append(test_size)
            self.split_names.append('test')
            self.loader_functions['test_dataloader'] = lambda: PinaDataLoader(
                self.splits['test'], self.batch_size, self.condition_names)
        if val_size > 0:
            self.split_length.append(val_size)
            self.split_names.append('val')
            self.loader_functions['val_dataloader'] = lambda: PinaDataLoader(
                self.splits['val'], self.batch_size, self.condition_names)
        if predict_size > 0:
            self.split_length.append(predict_size)
            self.split_names.append('predict')
            self.loader_functions['predict_dataloader'] = lambda: PinaDataLoader(
                self.splits['predict'], self.batch_size, self.condition_names)
        self.splits = {k: {} for k in self.split_names}
        self.shuffle = shuffle
        self.repeat = repeat

        for k, v in self.loader_functions.items():
            setattr(self, k, v)

    def prepare_data(self):
        if self.datasets is None:
            self._create_datasets()
        # Begin Data splitting
        splits_dict = {}
        if train_size > 0:
            splits_dict['train'] = train_size
            self.train_dataset = None
        else:
            self.train_dataloader = super().train_dataloader
        if test_size > 0:
            splits_dict['test'] = test_size
            self.test_dataset = None
        else:
            self.test_dataloader = super().test_dataloader
        if val_size > 0:
            splits_dict['val'] = val_size
            self.val_dataset = None
        else:
            self.val_dataloader = super().val_dataloader
        if predict_size > 0:
            splits_dict['predict'] = predict_size
            self.predict_dataset = None
        else:
            self.predict_dataloader = super().predict_dataloader
        self.collector_splits = self._create_splits(collector, splits_dict)

    def setup(self, stage=None):
        """
        Perform the splitting of the dataset
        """
        logging.debug('Start setup of Pina DataModule obj')
        if self.datasets is None:
            self._create_datasets()
        if stage == 'fit' or stage is None:
            for dataset in self.datasets:
                if len(dataset) > 0:
                    splits = self.dataset_split(dataset,
                                                self.split_length,
                                                shuffle=self.shuffle)
                    for i in range(len(self.split_length)):
                        self.splits[self.split_names[i]][
                            dataset.data_type] = splits[i]
            self.train_dataset = PinaDatasetFactory(
                self.collector_splits['train'],
                max_conditions_lengths=self.find_max_conditions_lengths(
                    'train'))
            if 'val' in self.collector_splits.keys():
                self.val_dataset = PinaDatasetFactory(
                    self.collector_splits['val'],
                    max_conditions_lengths=self.find_max_conditions_lengths(
                        'val')
                )
        elif stage == 'test':
            raise NotImplementedError("Testing pipeline not implemented yet")
            self.test_dataset = PinaDatasetFactory(
                self.collector_splits['test'],
                max_conditions_lengths=self.find_max_conditions_lengths(
                    'test')
            )
        elif stage == 'predict':
            self.predict_dataset = PinaDatasetFactory(
                self.collector_splits['predict'],
                max_conditions_lengths=self.find_max_conditions_lengths(
                    'predict')
            )
        else:
            raise ValueError("stage must be either 'fit' or 'test'")
            raise ValueError(
                "stage must be either 'fit' or 'test' or 'predict'."
            )

    @staticmethod
    def dataset_split(dataset, lengths, seed=None, shuffle=True):
        """
        Perform the splitting of the dataset
        :param dataset: dataset object we want to split
        :param lengths: lengths of elements in dataset
        :param seed: random seed
        :param shuffle: shuffle dataset
        :return: split dataset
        :rtype: PinaSubset
        """
        if sum(lengths) - 1 < 1e-3:
            len_dataset = len(dataset)
            lengths = [
                int(math.floor(len_dataset * length)) for length in lengths
            ]
            remainder = len(dataset) - sum(lengths)
            for i in range(remainder):
                lengths[i % len(lengths)] += 1
        elif sum(lengths) - 1 >= 1e-3:
            raise ValueError(f"Sum of lengths is {sum(lengths)}, greater than 1")
    def _split_condition(condition_dict, splits_dict):
        len_condition = len(condition_dict['input_points'])

        if shuffle:
            if seed is not None:
                generator = torch.Generator()
                generator.manual_seed(seed)
                indices = torch.randperm(sum(lengths), generator=generator)
            else:
                indices = torch.randperm(sum(lengths))
            dataset.apply_shuffle(indices)

        indices = torch.arange(0, sum(lengths), 1, dtype=torch.uint8).tolist()
        offsets = [
            sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
        ]
        return [
            PinaSubset(dataset, indices[offset:offset + length])
            for offset, length in zip(offsets, lengths)
        lengths = [
            int(math.floor(len_condition * length)) for length in
            splits_dict.values()
        ]

    def _create_datasets(self):
        remainder = len_condition - sum(lengths)
        for i in range(remainder):
            lengths[i % len(lengths)] += 1
        splits_dict = {k: v for k, v in zip(splits_dict.keys(), lengths)
                       }
        to_return_dict = {}
        offset = 0
        for stage, stage_len in splits_dict.items():
            to_return_dict[stage] = {k: v[offset:offset + stage_len]
                                     for k, v in condition_dict.items() if
                                     k != 'equation'
                                     # Equations are NEVER dataloaded
                                     }
            offset += stage_len
        return to_return_dict
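To make the split arithmetic concrete, a standalone sketch of the floor-then-distribute-remainder scheme used by `_split_condition` (the sizes below are illustrative):

```python
import math

n_points = 10
splits = {"train": 0.7, "test": 0.2, "val": 0.1}

# Floor each split first ...
lengths = [int(math.floor(n_points * frac)) for frac in splits.values()]
# ... then hand the leftover points out round-robin.
remainder = n_points - sum(lengths)
for i in range(remainder):
    lengths[i % len(lengths)] += 1

print(dict(zip(splits.keys(), lengths)))  # {'train': 7, 'test': 2, 'val': 1}
```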

    def _create_splits(self, collector, splits_dict):
        """
        Create the dataset objects putting data
        """
        logging.debug('Dataset creation in PinaDataModule obj')
        collector = self.problem.collector
        batching_dim = self.problem.batching_dimension
        datasets_slots = [i.__slots__ for i in self.dataset_classes]
        self.datasets = [
            dataset(device=self.device) for dataset in self.dataset_classes
        ]
        logging.debug('Filling datasets in PinaDataModule obj')
        for name, data in collector.data_collections.items():
            keys = list(data.keys())
            idx = [
                key for key, val in collector.conditions_name.items()
                if val == name
            ]
            for i, slot in enumerate(datasets_slots):
                if slot == keys:
                    self.datasets[i].add_points(data, idx[0], batching_dim)

        # ----------- Auxiliary function ------------
        def _apply_shuffle(condition_dict, len_data):
            idx = torch.randperm(len_data)
            for k, v in condition_dict.items():
                if k == 'equation':
                    continue
        datasets = []
        for dataset in self.datasets:
            if not dataset.empty:
                dataset.initialize()
                datasets.append(dataset)
        self.datasets = datasets
                if isinstance(v, list):
                    condition_dict[k] = [v[i] for i in idx]
                elif isinstance(v, LabelTensor):
                    condition_dict[k] = LabelTensor(v.tensor[idx],
                                                    v.labels)
                elif isinstance(v, torch.Tensor):
                    condition_dict[k] = v[idx]
                else:
                    raise ValueError(f"Data type {type(v)} not supported")
        # ----------- End auxiliary function ------------

        logging.debug('Dataset creation in PinaDataModule obj')
        split_names = list(splits_dict.keys())
        dataset_dict = {name: {} for name in split_names}
        for condition_name, condition_dict in collector.data_collections.items():
            len_data = len(condition_dict['input_points'])
            if self.shuffle:
                _apply_shuffle(condition_dict, len_data)
            for key, data in self._split_condition(condition_dict,
                                                   splits_dict).items():
                dataset_dict[key].update({condition_name: data})
        return dataset_dict

    def find_max_conditions_lengths(self, split):
        max_conditions_lengths = {}
        for k, v in self.collector_splits[split].items():
            if self.batch_size is None:
                max_conditions_lengths[k] = len(v['input_points'])
            elif self.repeat:
                max_conditions_lengths[k] = self.batch_size
            else:
                max_conditions_lengths[k] = min(len(v['input_points']),
                                                self.batch_size)
        return max_conditions_lengths
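A tiny worked example of the per-condition cap computed above (the numbers are invented): with `batch_size=8`, a condition holding 5 points is capped at 5 unless `repeat=True`, in which case its points are recycled to fill the batch:

```python
def max_condition_length(n_points, batch_size, repeat):
    # Mirrors find_max_conditions_lengths for a single condition.
    if batch_size is None:
        return n_points
    if repeat:
        return batch_size
    return min(n_points, batch_size)

print(max_condition_length(5, 8, repeat=False))  # 5
print(max_condition_length(5, 8, repeat=True))   # 8
```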

    def val_dataloader(self):
        """
        Create the validation dataloader
        """

        batch_size = self.batch_size if self.batch_size is not None else len(
            self.val_dataset)

        # Use default batching in torch DataLoader (good if batch size is small)
        if self.default_batching:
            collate = Collator(self.find_max_conditions_lengths('val'))
            return DataLoader(self.val_dataset, self.batch_size,
                              collate_fn=collate)
        collate = Collator(None)
        # Use custom batching (good if batch size is large)
        sampler = PinaBatchSampler(self.val_dataset, batch_size, shuffle=False)
        return DataLoader(self.val_dataset, sampler=sampler,
                          collate_fn=collate)

    def train_dataloader(self):
        """
        Create the training dataloader
        """
        # Use default batching in torch DataLoader (good if batch size is small)
        if self.default_batching:
            collate = Collator(self.find_max_conditions_lengths('train'))
            return DataLoader(self.train_dataset, self.batch_size,
                              collate_fn=collate)
        collate = Collator(None)
        # Use custom batching (good if batch size is large)
        batch_size = self.batch_size if self.batch_size is not None else len(
            self.train_dataset)
        sampler = PinaBatchSampler(self.train_dataset, batch_size,
                                   shuffle=False)
        return DataLoader(self.train_dataset, sampler=sampler,
                          collate_fn=collate)
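The "custom batching" branch relies on a torch detail worth spelling out: when a `BatchSampler` is passed as the plain `sampler` argument, each "index" the `DataLoader` hands to the dataset is a whole list, so `__getitem__` can fetch the full batch in one shot and the collate function only has to unwrap it. A minimal runnable sketch (toy dataset, invented sizes):

```python
import torch
from torch.utils.data import (BatchSampler, DataLoader, Dataset,
                              SequentialSampler)

class ListIndexDataset(Dataset):
    """Toy dataset whose __getitem__ accepts a list of indices."""
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):  # idx is a list of indices here
        return self.data[idx]

data = torch.arange(10.)
ds = ListIndexDataset(data)
sampler = BatchSampler(SequentialSampler(ds), batch_size=4, drop_last=False)
# collate_fn=lambda b: b[0] plays the role of Collator(None) above.
loader = DataLoader(ds, sampler=sampler, collate_fn=lambda b: b[0])
for batch in loader:
    print(batch)  # tensor([0., 1., 2., 3.]) etc.
```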

    def test_dataloader(self):
        """
        Create the testing dataloader
        """
        raise NotImplementedError("Test dataloader not implemented")

    def predict_dataloader(self):
        """
        Create the prediction dataloader
        """
        raise NotImplementedError("Predict dataloader not implemented")

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        """
        Transfer the batch to the device. This method is called in the
        training loop and is used to transfer the batch to the device.
        """
        batch = [
            (k, super(LightningDataModule, self).transfer_batch_to_device(
                v, device, dataloader_idx))
            for k, v in batch.items()
        ]
        return batch

102 pina/data/dataset.py (new file)
@@ -0,0 +1,102 @@
"""
This module provides basic data management functionalities
"""
import torch
from torch.utils.data import Dataset
from abc import abstractmethod
from torch_geometric.data import Batch


class PinaDatasetFactory:
    """
    Factory class for the PINA dataset. Depending on the type inside the
    conditions it creates a different dataset object:
        - PinaTensorDataset for torch.Tensor
        - PinaGraphDataset for list of torch_geometric.data.Data objects
    """
    def __new__(cls, conditions_dict, **kwargs):
        if len(conditions_dict) == 0:
            raise ValueError('No conditions provided')
        if all([isinstance(v['input_points'], torch.Tensor) for v
                in conditions_dict.values()]):
            return PinaTensorDataset(conditions_dict, **kwargs)
        elif all([isinstance(v['input_points'], list) for v
                  in conditions_dict.values()]):
            return PinaGraphDataset(conditions_dict, **kwargs)
        raise ValueError('Conditions must be either torch.Tensor or list of Data '
                         'objects.')


class PinaDataset(Dataset):
    """
    Abstract class for the PINA dataset
    """
    def __init__(self, conditions_dict, max_conditions_lengths):
        self.conditions_dict = conditions_dict
        self.max_conditions_lengths = max_conditions_lengths
        self.conditions_length = {k: len(v['input_points']) for k, v in
                                  self.conditions_dict.items()}
        self.length = max(self.conditions_length.values())

    def _get_max_len(self):
        max_len = 0
        for condition in self.conditions_dict.values():
            max_len = max(max_len, len(condition['input_points']))
        return max_len

    def __len__(self):
        return self.length

    @abstractmethod
    def __getitem__(self, item):
        pass


class PinaTensorDataset(PinaDataset):
    def __init__(self, conditions_dict, max_conditions_lengths,
                 ):
        super().__init__(conditions_dict, max_conditions_lengths)

    def _getitem_int(self, idx):
        return {
            k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data
                in v.keys()} for k, v in self.conditions_dict.items()
        }

    def _getitem_list(self, idx):
        to_return_dict = {}
        for condition, data in self.conditions_dict.items():
            cond_idx = idx[:self.max_conditions_lengths[condition]]
            condition_len = self.conditions_length[condition]
            if self.length > condition_len:
                cond_idx = [idx % condition_len for idx in cond_idx]
            to_return_dict[condition] = {k: v[cond_idx]
                                         for k, v in data.items()}
        return to_return_dict

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return self._getitem_int(idx)
        return self._getitem_list(idx)
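The modulo trick above is what lets conditions of different sizes live in one dataset: the dataset length is the longest condition, and shorter conditions are cycled. A standalone sketch with made-up condition names:

```python
import torch

conditions = {
    "pde": {"input_points": torch.rand(10, 2)},  # 10 collocation points
    "bc":  {"input_points": torch.rand(4, 2)},   # only 4 boundary points
}
length = max(len(v["input_points"]) for v in conditions.values())  # 10

idx = 7  # any index up to length - 1
item = {
    name: {field: tensor[idx % len(data["input_points"])]
           for field, tensor in data.items()}
    for name, data in conditions.items()
}
# "pde" returns its 7th point; "bc" wraps around to its 3rd (7 % 4).
```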

class PinaGraphDataset(PinaDataset):
    pass
    """
    def __init__(self, conditions_dict, max_conditions_lengths):
        super().__init__(conditions_dict, max_conditions_lengths)

    def __getitem__(self, idx):

        Getitem method for large batch size

        to_return_dict = {}
        for condition, data in self.conditions_dict.items():
            cond_idx = idx[:self.max_conditions_lengths[condition]]
            condition_len = self.conditions_length[condition]
            if self.length > condition_len:
                cond_idx = [idx % condition_len for idx in cond_idx]
            to_return_dict[condition] = {k: Batch.from_data_list([v[i]
                                                                  for i in cond_idx])
                                         if isinstance(v, list)
                                         else v[cond_idx].tensor.reshape(-1, v.size(-1))
                                         for k, v in data.items()
                                         }
        return to_return_dict
    """
@@ -1,47 +0,0 @@
"""
Batch management module
"""
from .pina_subset import PinaSubset


class Batch:
    """
    Implementation of the Batch class used during training to perform SGD
    optimization.
    """

    def __init__(self, dataset_dict, idx_dict, require_grad=True):
        self.attributes = []
        for k, v in dataset_dict.items():
            setattr(self, k, v)
            self.attributes.append(k)

        for k, v in idx_dict.items():
            setattr(self, k + '_idx', v)
        self.require_grad = require_grad

    def __len__(self):
        """
        Returns the number of elements in the batch
        :return: number of elements in the batch
        :rtype: int
        """
        length = 0
        for dataset in dir(self):
            attribute = getattr(self, dataset)
            if isinstance(attribute, list):
                length += len(getattr(self, dataset))
        return length

    def __getattribute__(self, item):
        if item in super().__getattribute__('attributes'):
            dataset = super().__getattribute__(item)
            index = super().__getattribute__(item + '_idx')
            return PinaSubset(dataset.dataset, dataset.indices[index])
        return super().__getattribute__(item)

    def __getattr__(self, item):
        if item == 'data' and len(self.attributes) == 1:
            item = self.attributes[0]
            return super().__getattribute__(item)
        raise AttributeError(f"'Batch' object has no attribute '{item}'")
@@ -1,68 +0,0 @@
"""
This module is used to create an iterable object used during training
"""
import math
from .pina_batch import Batch


class PinaDataLoader:
    """
    This class is used to create a dataloader to use during the training.

    :var condition_names: The names of the conditions. The order is consistent
        with the condition indices in the batches.
    :vartype condition_names: list[str]
    """

    def __init__(self, dataset_dict, batch_size, condition_names) -> None:
        """
        Initialize local variables
        :param dataset_dict: Dictionary of datasets
        :type dataset_dict: dict
        :param batch_size: Size of the batch
        :type batch_size: int
        :param condition_names: Names of the conditions
        :type condition_names: list[str]
        """
        self.condition_names = condition_names
        self.dataset_dict = dataset_dict
        self._init_batches(batch_size)

    def _init_batches(self, batch_size=None):
        """
        Create batches according to the batch_size provided in input.
        """
        self.batches = []
        n_elements = sum(len(v) for v in self.dataset_dict.values())
        if batch_size is None:
            batch_size = n_elements
        indexes_dict = {}
        n_batches = int(math.ceil(n_elements / batch_size))
        for k, v in self.dataset_dict.items():
            if n_batches != 1:
                indexes_dict[k] = math.floor(len(v) / (n_batches - 1))
            else:
                indexes_dict[k] = len(v)
        for i in range(n_batches):
            temp_dict = {}
            for k, v in indexes_dict.items():
                if i != n_batches - 1:
                    temp_dict[k] = slice(i * v, (i + 1) * v)
                else:
                    temp_dict[k] = slice(i * v, len(self.dataset_dict[k]))
            self.batches.append(
                Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))

    def __iter__(self):
        """
        Makes dataloader object iterable
        """
        yield from self.batches

    def __len__(self):
        """
        Return the number of batches.
        :return: The number of batches.
        :rtype: int
        """
        return len(self.batches)
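For intuition about the (now removed) slicing scheme, a standalone sketch of how each batch gets a per-dataset slice (the dataset names and sizes are invented):

```python
import math

dataset_sizes = {"supervised": 12, "physics": 6}
batch_size = 6
n_elements = sum(dataset_sizes.values())             # 18
n_batches = int(math.ceil(n_elements / batch_size))  # 3

# Per-batch stride for each dataset, as in _init_batches above.
strides = {k: (math.floor(v / (n_batches - 1)) if n_batches != 1 else v)
           for k, v in dataset_sizes.items()}
for i in range(n_batches):
    slices = {k: slice(i * s, (i + 1) * s if i != n_batches - 1
                       else dataset_sizes[k])
              for k, s in strides.items()}
    print(i, slices)
```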
@@ -1,36 +0,0 @@
"""
Module for PinaSubset class
"""
from pina import LabelTensor
from torch import Tensor, float32


class PinaSubset:
    """
    TODO
    """
    __slots__ = ['dataset', 'indices', 'require_grad']

    def __init__(self, dataset, indices, require_grad=True):
        """
        TODO
        """
        self.dataset = dataset
        self.indices = indices
        self.require_grad = require_grad

    def __len__(self):
        """
        TODO
        """
        return len(self.indices)

    def __getattr__(self, name):
        tensor = self.dataset.__getattribute__(name)
        if isinstance(tensor, (LabelTensor, Tensor)):
            tensor = tensor[[self.indices]].to(self.dataset.device)
            return tensor.requires_grad_(
                self.require_grad) if tensor.dtype == float32 else tensor
        if isinstance(tensor, list):
            return [tensor[i] for i in self.indices]
        raise AttributeError(f"No attribute named {name}")
@@ -1,35 +0,0 @@
"""
Sample dataset module
"""
from copy import deepcopy
from .base_dataset import BaseDataset
from ..condition import InputPointsEquationCondition


class SamplePointDataset(BaseDataset):
    """
    This class extends the BaseDataset to handle physical datasets
    composed of only input points.
    """
    data_type = 'physics'
    __slots__ = InputPointsEquationCondition.__slots__

    def add_points(self, data_dict, condition_idx, batching_dim=0):
        data_dict = deepcopy(data_dict)
        data_dict.pop('equation')
        super().add_points(data_dict, condition_idx)

    def _init_from_problem(self, collector_dict):
        for name, data in collector_dict.items():
            keys = list(data.keys())
            if set(self.__slots__) == set(keys):
                data = deepcopy(data)
                data.pop('equation')
                self._populate_init_list(data)
                idx = [
                    key for key, val in
                    self.problem.collector.conditions_name.items()
                    if val == name
                ]
                self.conditions_idx.append(idx)
        self.initialize()
@@ -1,13 +0,0 @@
"""
Supervised dataset module
"""
from .base_dataset import BaseDataset


class SupervisedDataset(BaseDataset):
    """
    This class extends the BaseDataset to handle datasets that consist of
    input-output pairs.
    """
    data_type = 'supervised'
    __slots__ = ['input_points', 'output_points']
@@ -1,14 +0,0 @@
"""
Unsupervised dataset module
"""
from .base_dataset import BaseDataset


class UnsupervisedDataset(BaseDataset):
    """
    This class extends the BaseDataset class to handle
    unsupervised datasets, composed of input points
    and, optionally, conditional variables
    """
    data_type = 'unsupervised'
    __slots__ = ['input_points', 'conditional_variables']
@@ -93,8 +93,8 @@ class Graph:

        logging.debug("edge_index computed")
        return Data(
            x=nodes_data,
            pos=nodes_coordinates,
            x=nodes_data.tensor,
            pos=nodes_coordinates.tensor,
            edge_index=edge_index,
            edge_attr=edges_data,
        )

@@ -4,26 +4,20 @@ import torch
from torch import Tensor


def issubset(a, b):
    """
    Check if a is a subset of b.
    """
    if isinstance(a, list) and isinstance(b, list):
        return set(a).issubset(set(b))
    if isinstance(a, range) and isinstance(b, range):
        return a.start <= b.start and a.stop >= b.stop
    return False

full_labels = True
MATH_FUNCTIONS = {torch.sin, torch.cos}

class LabelTensor(torch.Tensor):
    """Torch tensor with a label for any column."""

    @staticmethod
    def __new__(cls, x, labels, *args, **kwargs):
        full = kwargs.pop("full", full_labels)

        if isinstance(x, LabelTensor):
            x.full = full
            return x
        else:
            return super().__new__(cls, x, *args, **kwargs)
        return super().__new__(cls, x, *args, **kwargs)

    @property
    def tensor(self):
@@ -40,22 +34,11 @@ class LabelTensor(torch.Tensor):
        {1: {"name": "space", "dof": ['a', 'b', 'c']}}

        """
        self.dim_names = None
        self.full = kwargs.get('full', True)
        self.labels = labels

    @classmethod
    def __internal_init__(cls,
                          x,
                          labels,
                          dim_names,
                          *args,
                          **kwargs):
        lt = cls.__new__(cls, x, labels, *args, **kwargs)
        lt._labels = labels
        lt.full = kwargs.get('full', True)
        lt.dim_names = dim_names
        return lt
        self.full = kwargs.get('full', full_labels)
        if labels is not None:
            self.labels = labels
        else:
            self._labels = {}

    @property
    def labels(self):
@@ -104,14 +87,13 @@ class LabelTensor(torch.Tensor):
        self._labels = {}
        if isinstance(labels, dict):
            self._init_labels_from_dict(labels)
        elif isinstance(labels, list):
        elif isinstance(labels, (list, range)):
            self._init_labels_from_list(labels)
        elif isinstance(labels, str):
            labels = [labels]
            self._init_labels_from_list(labels)
        else:
            raise ValueError("labels must be list, dict or string.")
        self.set_names()

    def _init_labels_from_dict(self, labels):
        """
@@ -125,34 +107,38 @@ class LabelTensor(torch.Tensor):
        """
        tensor_shape = self.shape

        # Set all labels if full_labels is True
        if hasattr(self, 'full') and self.full:
            labels = {
                i: labels[i] if i in labels else {
                    'name': i
                    'name': i, 'dof': range(tensor_shape[i])
                }
                for i in labels.keys()
                for i in range(len(tensor_shape))
            }

        for k, v in labels.items():

            # Init labels from str
            if isinstance(v, str):
                v = {'name': v, 'dof': range(tensor_shape[k])}

            # Init labels from dict
            elif isinstance(v, dict) and list(v.keys()) == ['name']:
                # Init from dict with only name key
                v['dof'] = range(tensor_shape[k])
            # Init from dict with both name and dof keys
            elif isinstance(v, dict) and sorted(list(
                    v.keys())) == ['dof', 'name']:
                dof_list = v['dof']
                dof_len = len(dof_list)
                if dof_len != len(set(dof_list)):
                    raise ValueError("dof must be unique")
                if dof_len != tensor_shape[k]:
                    raise ValueError(
                        'Number of dof does not match tensor shape')
            elif isinstance(v, dict):
                # Only the name of the dimension is provided
                if list(v.keys()) == ['name']:
                    v['dof'] = range(tensor_shape[k])
                # Both name and dof are provided
                elif sorted(list(v.keys())) == ['dof', 'name']:
                    dof_list = v['dof']
                    dof_len = len(dof_list)
                    if dof_len != len(set(dof_list)):
                        raise ValueError("dof must be unique")
                    if dof_len != tensor_shape[k]:
                        raise ValueError(
                            'Number of dof does not match tensor shape')
                else:
                    raise ValueError('Illegal labels initialization')
            # Perform update
            # Assign labels values
            self._labels[k] = v

    def _init_labels_from_list(self, labels):
@@ -172,75 +158,71 @@ class LabelTensor(torch.Tensor):
        }
        self._init_labels_from_dict(last_dim_labels)

    def set_names(self):
        labels = self.stored_labels
        self.dim_names = {}
        for dim in labels.keys():
            self.dim_names[labels[dim]['name']] = dim

    def extract(self, labels_to_extract):
        """
        Extract the subset of the original tensor by returning all the columns
        corresponding to the passed ``label_to_extract``.

        :param label_to_extract: The label(s) to extract.
        :type label_to_extract: str | list(str) | tuple(str)
        :param labels_to_extract: The label(s) to extract.
        :type labels_to_extract: str | list(str) | tuple(str)
        :raises TypeError: Labels are not ``str``.
        :raises ValueError: Label to extract is not in the labels ``list``.
        """
        # Convert str/int to string
        def find_names(labels):
            dim_names = {}
            for dim in labels.keys():
                dim_names[labels[dim]['name']] = dim
            return dim_names

        if isinstance(labels_to_extract, (str, int)):
            labels_to_extract = [labels_to_extract]

        # Store useful variables
        labels = self.stored_labels
        labels = copy(self._labels)
        stored_keys = labels.keys()
        dim_names = self.dim_names
        dim_names = find_names(labels)
        ndim = len(super().shape)

        # Convert tuple/list to dict
        # Convert tuple/list to dict (having a list as input
        # means that we want to extract values from the last dimension)
        if isinstance(labels_to_extract, (tuple, list)):
            if not ndim - 1 in stored_keys:
                raise ValueError(
                    "LabelTensor does not have labels in last dimension")
            name = labels[max(stored_keys)]['name']
            name = labels[ndim - 1]['name']
            labels_to_extract = {name: list(labels_to_extract)}

        # If labels_to_extract is not a dict, then raise an error
        if not isinstance(labels_to_extract, dict):
            raise ValueError('labels_to_extract must be str or list or dict')

        # Make a copy of labels (avoid consistency issues)
        updated_labels = {k: copy(v) for k, v in labels.items()}

        # Initialize list used to perform extraction
        extractor = [slice(None) for _ in range(ndim)]
        extractor = [slice(None)] * ndim

        # Loop over labels_to_extract dict
        for k, v in labels_to_extract.items():
        for dim_name, labels_te in labels_to_extract.items():

            # If the label is not found, raise a ValueError
            idx_dim = dim_names.get(k)
            idx_dim = dim_names.get(dim_name, None)
            if idx_dim is None:
                raise ValueError(
                    'Cannot extract a label that is not in the original labels')

            dim_labels = labels[idx_dim]['dof']
            v = [v] if isinstance(v, (int, str)) else v

            if not isinstance(v, range):
                extractor[idx_dim] = [dim_labels.index(i)
                                      for i in v] if len(v) > 1 else slice(
                                          dim_labels.index(v[0]),
                                          dim_labels.index(v[0]) + 1)
            labels_te = [labels_te] if isinstance(labels_te, (int, str)) else labels_te
            if not isinstance(labels_te, range):
                # The slice keeps the dimension when only one label is extracted
                extractor[idx_dim] = [dim_labels.index(i) for i in labels_te] \
                    if len(labels_te) > 1 else slice(dim_labels.index(labels_te[0]), dim_labels.index(labels_te[0]) + 1)
            else:
                extractor[idx_dim] = slice(v.start, v.stop)
                extractor[idx_dim] = slice(labels_te.start, labels_te.stop)

            updated_labels.update({idx_dim: {'dof': v, 'name': k}})
            labels.update({idx_dim: {'dof': labels_te, 'name': dim_name}})

        tensor = self.tensor
        tensor = tensor[extractor]
        return LabelTensor.__internal_init__(tensor, updated_labels, dim_names)
        tensor = super().__getitem__(extractor).as_subclass(LabelTensor)
        tensor._labels = labels
        return tensor
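The core of `extract` is building a per-dimension `extractor` of slices/index lists and applying it in one indexing call. A stripped-down, runnable sketch with plain tensors and a hand-made label map (the names are invented):

```python
import torch

t = torch.rand(5, 3)          # rows x labelled columns
dof = ["u", "v", "p"]         # labels for the last dimension

def extract(tensor, labels):
    labels = [labels] if isinstance(labels, str) else labels
    extractor = [slice(None)] * tensor.ndim
    cols = [dof.index(l) for l in labels]
    # A length-1 slice keeps the dimension instead of dropping it.
    extractor[-1] = cols if len(cols) > 1 else slice(cols[0], cols[0] + 1)
    return tensor[tuple(extractor)]

assert extract(t, "p").shape == (5, 1)        # dimension kept
assert extract(t, ["u", "p"]).shape == (5, 2)
```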

    def __str__(self):
        """
@@ -272,39 +254,53 @@ class LabelTensor(torch.Tensor):
            return []
        if len(tensors) == 1 or isinstance(tensors, LabelTensor):
            return tensors[0]

        # Perform cat on tensors
        new_tensor = torch.cat(tensors, dim=dim)

        # Update labels
        labels = LabelTensor.__create_labels_cat(tensors, dim)
        # --------- Start definition auxiliary function ------
        # Compute and update labels
        def create_labels_cat(tensors, dim, tensor_shape):
            stored_labels = [tensor.stored_labels for tensor in tensors]
            keys = stored_labels[0].keys()

        return LabelTensor.__internal_init__(new_tensor, labels,
                                             tensors[0].dim_names)
            if any(not all(stored_labels[i][k] == stored_labels[0][k] for i in
                           range(len(stored_labels))) for k in keys if k != dim):
                raise RuntimeError('tensors must have the same shape and dof')

            # Copy labels from the first tensor and update the 'dof' for dimension `dim`
            labels = copy(stored_labels[0])
            if dim in labels:
                labels_list = [tensor[dim]['dof'] for tensor in stored_labels]
                last_dim_dof = range(tensor_shape[dim]) if all(isinstance(label, range)
                    for label in labels_list) else sum(labels_list, [])
                labels[dim]['dof'] = last_dim_dof
            return labels
        # --------- End definition auxiliary function ------

        # Update labels
        if dim in tensors[0].stored_labels.keys():
            new_tensor_shape = new_tensor.shape
            labels = create_labels_cat(tensors, dim, new_tensor_shape)
        else:
            labels = tensors[0].stored_labels
        new_tensor._labels = labels
        return new_tensor

    @staticmethod
    def __create_labels_cat(tensors, dim):
        # Check if names and dof of the labels are the same in all dimensions
        # except in dim
        stored_labels = [tensor.stored_labels for tensor in tensors]

        # check if:
        # - labels dict have same keys
        # - all labels are the same except for dimension dim
        if not all(
                all(stored_labels[i][k] == stored_labels[0][k]
                    for i in range(len(stored_labels)))
                for k in stored_labels[0].keys() if k != dim):
            raise RuntimeError('tensors must have the same shape and dof')

        labels = {k: copy(v) for k, v in tensors[0].stored_labels.items()}
        if dim in labels.keys():
            last_dim_dof = [i for j in stored_labels for i in j[dim]['dof']]
            labels[dim]['dof'] = last_dim_dof
        return labels
    def stack(tensors):
        new_tensor = torch.stack(tensors)
        labels = tensors[0]._labels
        labels = {key + 1: value for key, value in labels.items()}
        if full_labels:
            new_tensor.labels = labels
        else:
            new_tensor._labels = labels
        return new_tensor

    def requires_grad_(self, mode=True):
        lt = super().requires_grad_(mode)
        lt.labels = self._labels
        lt._labels = self._labels
        return lt

    @property
@@ -316,10 +312,9 @@ class LabelTensor(torch.Tensor):
        Performs Tensor dtype and/or device conversion. For more details, see
        :meth:`torch.Tensor.to`.
        """
        tmp = super().to(*args, **kwargs)
        new = self.__class__.clone(self)
        new.data = tmp.data
        return new
        lt = super().to(*args, **kwargs)
        lt._labels = self._labels
        return lt

    def clone(self, *args, **kwargs):
        """
@@ -329,8 +324,7 @@ class LabelTensor(torch.Tensor):
        :return: A copy of the tensor.
        :rtype: LabelTensor
        """
        labels = {k: copy(v) for k, v in self._labels.items()}
        out = LabelTensor(super().clone(*args, **kwargs), labels)
        out = LabelTensor(super().clone(*args, **kwargs), deepcopy(self._labels))
        return out

    @staticmethod
@@ -348,7 +342,7 @@ class LabelTensor(torch.Tensor):
            raise RuntimeError('Tensors must have the same shape and labels')

        last_dim_labels = []
        data = torch.zeros(tensors[0].tensor.shape)
        data = torch.zeros(tensors[0].tensor.shape).to(tensors[0].device)
        for tensor in tensors:
            data += tensor.tensor
            last_dim_labels.append(tensor.labels)
@@ -396,82 +390,114 @@ class LabelTensor(torch.Tensor):
        """
        return LabelTensor.cat(label_tensors, dim=0)

    # ---------------------- Start auxiliary function definition -----
    # This method is used to update labels
    def _update_single_label(self, old_labels, to_update_labels, index, dim,
                             to_update_dim):
        """
        TODO
        :param old_labels: labels from which to retrieve data
        :param to_update_labels: labels to update
        :param index: index of dof to retain
        :param dim: label index
        :return:
        """
        old_dof = old_labels[to_update_dim]['dof']
        if isinstance(index, slice):
            to_update_labels.update({
                dim: {
                    'dof': old_dof[index],
                    'name': old_labels[dim]['name']
                }
            })
            return
        if isinstance(index, int):
            index = [index]
        if isinstance(index, (list, torch.Tensor)):
            to_update_labels.update({
                dim: {
                    'dof': [old_dof[i] for i in index] if isinstance(old_dof, list) else index,
                    'name': old_labels[dim]['name']
                }
            })
            return
        raise NotImplementedError(f'Getitem not implemented for '
                                  f'{type(index)} values')
    # ---------------------- End auxiliary function definition -----


    def __getitem__(self, index):
        """
        TODO: Complete docstring
        :param index:
        :return:
        """
        if isinstance(index,
                      str) or (isinstance(index, (tuple, list))
                               and all(isinstance(a, str) for a in index)):
        # Indexes are str --> call extract
        if isinstance(index, str) or (isinstance(index, (tuple, list))
                                      and all(
                                          isinstance(a, str) for a in index)):
            return self.extract(index)

        # Store important variables
        selected_lt = super().__getitem__(index)
        stored_labels = self._labels
        labels = copy(stored_labels)

        if isinstance(index, (int, slice)):
        # Put here because it is the most common case (int as index).
        # Used by DataLoader -> put here for efficiency purposes
        if isinstance(index, list):
            if 0 in labels.keys():
                self._update_single_label(stored_labels, labels, index,
                                          0, 0)
            selected_lt._labels = labels
            return selected_lt

        if isinstance(index, int):
            labels.pop(0, None)
            labels = {key - 1 if key > 0 else key: value for key, value in
                      labels.items()}
            selected_lt._labels = labels
            return selected_lt

        if not isinstance(index, (tuple, torch.Tensor)):
            index = [index]

        # Ellipsis is used to perform operations on the last dimension
        if index[0] == Ellipsis:
            index = [slice(None)] * (self.ndim - 1) + [index[1]]
            if len(self.shape) in labels:
                self._update_single_label(stored_labels, labels, index, 0, 0)
            selected_lt._labels = labels
            return selected_lt

        if hasattr(self, "labels"):
            labels = {k: copy(v) for k, v in self.stored_labels.items()}
            for j, idx in enumerate(index):
                if isinstance(idx, int):
        i = 0
        for j, idx in enumerate(index):
            if j in self.stored_labels.keys():
                if isinstance(idx, int) or (
                        isinstance(idx, torch.Tensor) and idx.ndim == 0):
                    selected_lt = selected_lt.unsqueeze(j)
                if j in labels.keys() and idx != slice(None):
                    self._update_single_label(labels, labels, idx, j)
        selected_lt = LabelTensor.__internal_init__(selected_lt, labels,
                                                    self.dim_names)
                if idx != slice(None):
                    self._update_single_label(stored_labels, labels, idx, j, i)
            else:
                if isinstance(idx, int):
                    labels = {key - 1 if key > j else key:
                              value for key, value in labels.items()}
                    continue
            i += 1
        selected_lt._labels = labels
        return selected_lt

    @staticmethod
    def _update_single_label(old_labels, to_update_labels, index, dim):
        """
        TODO
        :param old_labels: labels from which to retrieve data
        :param to_update_labels: labels to update
        :param index: index of dof to retain
        :param dim: label index
        :return:
        """
        old_dof = old_labels[dim]['dof']
        if not isinstance(
                index,
                (int, slice)) and len(index) == len(old_dof) and isinstance(
                    old_dof, range):
            return
        if isinstance(index, torch.Tensor):
            index = index.nonzero(
                as_tuple=True
            )[0] if index.dtype == torch.bool else index.tolist()
        if isinstance(index, list):
            to_update_labels.update({
                dim: {
                    'dof': [old_dof[i] for i in index],
                    'name': old_labels[dim]['name']
                }
            })
        else:
            to_update_labels.update(
                {dim: {
                    'dof': old_dof[index],
                    'name': old_labels[dim]['name']
                }})

    def sort_labels(self, dim=None):

        def arg_sort(lst):
            return sorted(range(len(lst)), key=lambda x: lst[x])

        if dim is None:
            dim = self.ndim - 1
        if self.shape[dim] == 1:
            return self
        labels = self.stored_labels[dim]['dof']
        sorted_index = arg_sort(labels)
        indexer = [slice(None)] * self.ndim
        indexer[dim] = sorted_index
        return self.__getitem__(indexer)
        return self.__getitem__(tuple(indexer))
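The argsort-then-index pattern above, shown standalone on a plain tensor (the column labels are invented):

```python
import torch

labels = ["y", "x", "z"]
t = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])

# Indices that put the labels in sorted order: x, y, z -> [1, 0, 2]
sorted_index = sorted(range(len(labels)), key=lambda i: labels[i])
indexer = [slice(None)] * t.ndim
indexer[-1] = sorted_index
print(t[tuple(indexer)])  # columns reordered to x, y, z
```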

    def __deepcopy__(self, memo):
        cls = self.__class__
@@ -480,10 +506,16 @@ class LabelTensor(torch.Tensor):

    def permute(self, *dims):
        tensor = super().permute(*dims)
        stored_labels = self.stored_labels
        labels = self._labels
        keys_list = list(*dims)
        labels = {
            keys_list.index(k): copy(stored_labels[k])
            for k in stored_labels.keys()
            keys_list.index(k): labels[k]
            for k in labels.keys()
        }
        return LabelTensor.__internal_init__(tensor, labels, self.dim_names)
        tensor._labels = labels
        return tensor

    def detach(self):
        lt = super().detach()
        lt._labels = self.stored_labels
        return lt
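A minimal, self-contained sketch (not PINA's implementation) of the bookkeeping `permute` does above: per-dimension metadata keyed by dimension index must be re-keyed after the axes move, since old dimension k ends up at position dims.index(k).

import torch

def permute_with_labels(tensor, labels, dims):
    # labels: {dim_index: {'name': ..., 'dof': [...]}}
    permuted = tensor.permute(*dims)
    new_labels = {dims.index(k): v for k, v in labels.items()}
    return permuted, new_labels

t = torch.rand(2, 3, 4)
labels = {2: {'name': 'space', 'dof': ['a', 'b', 'c', 'd']}}
p, new_labels = permute_with_labels(t, labels, (2, 0, 1))
assert p.shape == (4, 2, 3) and list(new_labels) == [0]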
@@ -119,6 +119,7 @@ class LowRankBlock(torch.nn.Module):
        :rtype: torch.Tensor
        """
        # extract basis
        coords = coords.as_subclass(torch.Tensor)
        basis = self._basis(coords)
        # reshape [B, N, D, 2*rank]
        shape = list(basis.shape[:-1]) + [-1, 2 * self.rank]

@@ -29,7 +29,8 @@ class Network(torch.nn.Module):
        # check model consistency
        check_consistency(model, nn.Module)
        check_consistency(input_variables, str)
        check_consistency(output_variables, str)
        if output_variables is not None:
            check_consistency(output_variables, str)

        self._model = model
        self._input_variables = input_variables
@@ -67,16 +68,15 @@ class Network(torch.nn.Module):
        # in case `input_variables = []` all points are used
        if self._input_variables:
            x = x.extract(self._input_variables)

        # extract features and append
        for feature in self._extra_features:
            x = x.append(feature(x))

        # perform forward pass + convert to LabelTensor
        output = self._model(x).as_subclass(LabelTensor)

        # set the labels for LabelTensor
        output.labels = self._output_variables
        x = x.as_subclass(torch.Tensor)
        output = self._model(x)
        if self._output_variables is not None:
            output = LabelTensor(output, self._output_variables)

        return output

@@ -97,15 +97,9 @@ class Network(torch.nn.Module):
        This function does not extract the input variables; all the variables
        are used for both tensors. Output variables are correctly applied.
        """
        # convert LabelTensors to torch.Tensors
        x = list(map(lambda x: x.as_subclass(torch.Tensor), x))

        # perform forward pass (using torch.Tensor) + convert to LabelTensor
        output = self._model(x).as_subclass(LabelTensor)

        # set the labels for LabelTensor
        output.labels = self._output_variables

        output = LabelTensor(self._model(x.tensor), self._output_variables)
        return output

    @property

@@ -63,11 +63,9 @@ def grad(output_, input_, components=None, d=None):
        retain_graph=True,
        allow_unused=True,
    )[0]

    gradients.labels = input_.labels
    gradients = gradients.extract(d)
    gradients.labels = input_.stored_labels
    gradients = gradients[..., [input_.labels.index(i) for i in d]]
    gradients.labels = [f"d{output_fieldname}d{i}" for i in d]

    return gradients

    if not isinstance(input_, LabelTensor):
@@ -216,7 +214,9 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
        to_append_tensors = []
        for i, label in enumerate(grad_output.labels):
            gg = grad(grad_output, input_, d=d, components=[label])
            to_append_tensors.append(gg.extract([gg.labels[i]]))
            gg = gg.extract([gg.labels[i]])

            to_append_tensors.append(gg)
        labels = [f"dd{components[0]}"]
        result = LabelTensor.summation(tensors=to_append_tensors)
        result.labels = labels

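A self-contained sketch of what the `grad` operator above computes with plain torch.autograd, including the column selection by label index that replaces `extract` (the list-based labels here are illustrative, not PINA's LabelTensor):

import torch

labels = ['x', 'y']
pts = torch.rand(10, 2, requires_grad=True)
u = (pts ** 2).sum(dim=1, keepdim=True)            # u = x^2 + y^2
grads = torch.autograd.grad(u, pts,
                            grad_outputs=torch.ones_like(u),
                            create_graph=True)[0]
d = ['y']                                           # derivatives to keep
dy = grads[..., [labels.index(i) for i in d]]       # column selection by label
assert torch.allclose(dy, 2 * pts[:, 1:2])          # du/dy = 2y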
@@ -90,10 +90,9 @@ class AbstractProblem(metaclass=ABCMeta):
            variables += self.spatial_variables
        if hasattr(self, "temporal_variable"):
            variables += self.temporal_variable
        if hasattr(self, "unknown_parameters"):
        if hasattr(self, "parameters"):
            variables += self.parameters
        if hasattr(self, "custom_variables"):
            variables += self.custom_variables

        return variables

@@ -170,7 +169,6 @@ class AbstractProblem(metaclass=ABCMeta):
                f"Wrong variables for sampling. Variables ",
                f"should be in {self.input_variables}.",
            )

        # check correct location
        if locations == "all":
            locations = [

@@ -1,7 +1,6 @@
"""Module for the ParametricProblem class"""

import torch
from abc import abstractmethod

from .abstract_problem import AbstractProblem


@@ -1,34 +0,0 @@
from .supervised import SupervisedSolver
from ..graph import Graph


class GraphSupervisedSolver(SupervisedSolver):

    def __init__(
            self,
            problem,
            model,
            nodes_coordinates,
            nodes_data,
            loss=None,
            optimizer=None,
            scheduler=None):
        super().__init__(problem, model, loss, optimizer, scheduler)
        if isinstance(nodes_coordinates, str):
            self._nodes_coordinates = [nodes_coordinates]
        else:
            self._nodes_coordinates = nodes_coordinates
        if isinstance(nodes_data, str):
            self._nodes_data = nodes_data
        else:
            self._nodes_data = nodes_data

    def forward(self, input):
        input_coords = input.extract(self._nodes_coordinates)
        input_data = input.extract(self._nodes_data)

        if not isinstance(input, Graph):
            input = Graph.build('radius', nodes_coordinates=input_coords,
                                nodes_data=input_data, radius=0.2)
        g = self.model(input.data, edge_index=input.data.edge_index)
        g.labels = {1: {'name': 'output', 'dof': ['u']}}
        return g
@@ -1,14 +1,15 @@
""" Module for PINN """

import sys
from abc import ABCMeta, abstractmethod
import torch

from ...solvers.solver import SolverInterface
from pina.utils import check_consistency
from pina.loss.loss_interface import LossInterface
from pina.problem import InverseProblem
from torch.nn.modules.loss import _Loss
from ...condition import InputOutputPointsCondition
from ...solvers.solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...condition import DomainEquationCondition
from ...optim import TorchOptimizer, TorchScheduler

torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732
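The module-level constant above relies on the identity arccos(0) = π/2, so doubling it recovers π up to float32 precision; a quick check:

import math
import torch

assert math.isclose(torch.acos(torch.zeros(1)).item() * 2, math.pi,
                    rel_tol=1e-6)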

@@ -25,13 +26,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
    to the user to choose which problem the implemented solver inheriting from
    this class is suitable for.
    """

    accepted_condition_types = [DomainEquationCondition.condition_type[0],
                                InputOutputPointsCondition.condition_type[0]]

    def __init__(
        self,
        models,
        problem,
        optimizers,
        optimizers_kwargs,
        schedulers,
        extra_features,
        loss,
    ):
@@ -53,11 +55,20 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
        :param torch.nn.Module loss: The loss function used as minimizer,
            default :class:`torch.nn.MSELoss`.
        """
        if optimizers is None:
            optimizers = TorchOptimizer(torch.optim.Adam, lr=0.001)

        if schedulers is None:
            schedulers = TorchScheduler(torch.optim.lr_scheduler.ConstantLR)

        if loss is None:
            loss = torch.nn.MSELoss()

        super().__init__(
            models=models,
            problem=problem,
            optimizers=optimizers,
            optimizers_kwargs=optimizers_kwargs,
            schedulers=schedulers,
            extra_features=extra_features,
        )

@@ -85,7 +96,12 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
        # variable will be stored with name = self.__logged_metric
        self.__logged_metric = None

    def training_step(self, batch, _):
        self._model = self._pina_models[0]
        self._optimizer = self._pina_optimizers[0]
        self._scheduler = self._pina_schedulers[0]

    def training_step(self, batch):
        """
        The Physics Informed Solver training step. This function takes care
        of the physics informed training step, and it must not be overridden
@@ -99,53 +115,68 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
        :rtype: LabelTensor
        """

        condition_losses = []
        condition_idx = batch["condition"]
        condition_loss = []
        for condition_name, points in batch:
            if 'output_points' in points:
                input_pts, output_pts = points['input_points'], points['output_points']

        for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

            condition_name = self._dataloader.condition_names[condition_id]
            condition = self.problem.conditions[condition_name]
            pts = batch["pts"]
            # condition name is logged (if logs enabled)
            self.__logged_metric = condition_name

            if len(batch) == 2:
                samples = pts[condition_idx == condition_id]
                loss = self.loss_phys(samples, condition.equation)
            elif len(batch) == 3:
                samples = pts[condition_idx == condition_id]
                ground_truth = batch["output"][condition_idx == condition_id]
                loss = self.loss_data(samples, ground_truth)
                loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
                condition_loss.append(loss_.as_subclass(torch.Tensor))
            else:
                raise ValueError("Batch size not supported")
                input_pts = points['input_points']

            # add condition losses for each epoch
            condition_losses.append(loss * condition.data_weight)
                condition = self.problem.conditions[condition_name]

                loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
                condition_loss.append(loss_.as_subclass(torch.Tensor))
        # clamp unknown parameters in InverseProblem (if needed)
        self._clamp_params()
        loss = sum(condition_loss)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True,
                 logger=True, batch_size=self.get_batch_size(batch),
                 sync_dist=True)

        # total loss (must be a torch.Tensor), and logs
        total_loss = sum(condition_losses)
        self.save_logs_and_release()
        return total_loss.as_subclass(torch.Tensor)
        return loss

    def loss_data(self, input_tensor, output_tensor):
    def validation_step(self, batch):
        """
        TODO: add docstring
        """
        condition_loss = []
        for condition_name, points in batch:
            if 'output_points' in points:
                input_pts, output_pts = points['input_points'], points['output_points']
                loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
                condition_loss.append(loss_.as_subclass(torch.Tensor))
            else:
                input_pts = points['input_points']

                condition = self.problem.conditions[condition_name]
                with torch.set_grad_enabled(True):
                    loss_ = self.loss_phys(input_pts.requires_grad_(), condition.equation)
                condition_loss.append(loss_.as_subclass(torch.Tensor))
        # clamp unknown parameters in InverseProblem (if needed)

        loss = sum(condition_loss)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=self.get_batch_size(batch),
                 sync_dist=True)

    def loss_data(self, input_pts, output_pts):
        """
        The data loss for the PINN solver. It computes the loss between
        the network output and the true solution. This function
        should not be overridden unless intentionally.

        :param LabelTensor input_tensor: The input to the neural networks.
        :param LabelTensor output_tensor: The true solution to compare the
        :param LabelTensor input_pts: The input to the neural networks.
        :param LabelTensor output_pts: The true solution to compare the
            network solution.
        :return: The residual loss averaged on the input coordinates
        :rtype: torch.Tensor
        """
        loss_value = self.loss(self.forward(input_tensor), output_tensor)
        self.store_log(loss_value=float(loss_value))
        return self.loss(self.forward(input_tensor), output_tensor)
        return self._loss(self.forward(input_pts), output_pts)

    @abstractmethod
    def loss_phys(self, samples, equation):
@@ -196,13 +227,17 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
        :param str name: The name of the loss.
        :param torch.Tensor loss_value: The value of the loss.
        """
        batch_size = self.trainer.data_module.batch_size \
            if self.trainer.data_module.batch_size is not None else 999

        self.log(
            self.__logged_metric + "_loss",
            loss_value,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=False,
            on_step=True,
            batch_size=batch_size,
        )
        self.__logged_res_losses.append(loss_value)

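The refactored steps above iterate the batch as (condition_name, points) pairs, where points is a dict holding 'input_points' and, for data-driven conditions, 'output_points'. A self-contained sketch of that contract with plain tensors (the batch contents, model, and losses are illustrative, not the PINA dataloader output):

import torch

batch = [
    ('data',   {'input_points': torch.rand(8, 2),
                'output_points': torch.rand(8, 1)}),
    ('gamma1', {'input_points': torch.rand(4, 2)}),
]

model = torch.nn.Linear(2, 1)
condition_loss = []
for condition_name, points in batch:
    if 'output_points' in points:     # supervised / data condition
        loss_ = torch.nn.functional.mse_loss(
            model(points['input_points']), points['output_points'])
    else:                             # physics condition: residual vs zero
        pts = points['input_points'].requires_grad_()
        loss_ = model(pts).pow(2).mean()
    condition_loss.append(loss_)
total = sum(condition_loss)           # summed exactly as in training_step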
@@ -9,10 +9,8 @@ except ImportError:
        _LRScheduler as LRScheduler,
    )  # torch < 2.0

from torch.optim.lr_scheduler import ConstantLR

from .basepinn import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem


@@ -56,16 +54,16 @@ class PINN(PINNInterface):
    DOI: `10.1038 <https://doi.org/10.1038/s42254-021-00314-5>`_.
    """

    __name__ = 'PINN'

    def __init__(
        self,
        problem,
        model,
        extra_features=None,
        loss=torch.nn.MSELoss(),
        optimizer=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001},
        scheduler=ConstantLR,
        scheduler_kwargs={"factor": 1, "total_iters": 0},
        loss=None,
        optimizer=None,
        scheduler=None,
    ):
        """
        :param AbstractProblem problem: The formulation of the problem.
@@ -82,20 +80,15 @@ class PINN(PINNInterface):
        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
        """
        super().__init__(
            models=[model],
            models=model,
            problem=problem,
            optimizers=[optimizer],
            optimizers_kwargs=[optimizer_kwargs],
            optimizers=optimizer,
            schedulers=scheduler,
            extra_features=extra_features,
            loss=loss,
        )

        # check consistency
        check_consistency(scheduler, LRScheduler, subclass=True)
        check_consistency(scheduler_kwargs, dict)

        # assign variables
        self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
        self._neural_net = self.models[0]

    def forward(self, x):
@@ -126,9 +119,8 @@ class PINN(PINNInterface):
        """
        residual = self.compute_residual(samples=samples, equation=equation)
        loss_value = self.loss(
            torch.zeros_like(residual, requires_grad=True), residual
            torch.zeros_like(residual), residual
        )
        self.store_log(loss_value=float(loss_value))
        return loss_value

    def configure_optimizers(self):
@@ -141,16 +133,21 @@ class PINN(PINNInterface):
        """
        # if the problem is an InverseProblem, add the unknown parameters
        # to the parameters that the optimizer needs to optimize

        self._optimizer.hook(self._model.parameters())
        if isinstance(self.problem, InverseProblem):
            self.optimizers[0].add_param_group(
                {
                    "params": [
                        self._params[var]
                        for var in self.problem.unknown_variables
                    ]
                }
            )
        return self.optimizers, [self.scheduler]
            self._optimizer.optimizer_instance.add_param_group(
                {
                    "params": [
                        self._params[var]
                        for var in self.problem.unknown_variables
                    ]
                }
            )
        self._scheduler.hook(self._optimizer)
        return ([self._optimizer.optimizer_instance],
                [self._scheduler.scheduler_instance])

    @property
    def scheduler(self):

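The `hook` calls in `configure_optimizers` above follow a deferred-instantiation pattern: the wrapper stores the optimizer class and kwargs, and only builds the torch optimizer once the model parameters exist. A minimal sketch of that pattern (this class is illustrative, not PINA's TorchOptimizer):

import torch

class LazyOptimizer:
    def __init__(self, optimizer_class, **kwargs):
        self.optimizer_class = optimizer_class
        self.kwargs = kwargs
        self.optimizer_instance = None

    def hook(self, parameters):
        # build the real torch optimizer now that parameters are available
        self.optimizer_instance = self.optimizer_class(parameters, **self.kwargs)

model = torch.nn.Linear(2, 1)
opt = LazyOptimizer(torch.optim.Adam, lr=0.001)
opt.hook(model.parameters())      # as in configure_optimizers above
assert isinstance(opt.optimizer_instance, torch.optim.Adam)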
@@ -2,7 +2,7 @@

from abc import ABCMeta, abstractmethod
from ..model.network import Network
import pytorch_lightning
import lightning
from ..utils import check_consistency
from ..problem import AbstractProblem
from ..optim import Optimizer, Scheduler
@@ -10,7 +8,7 @@ import torch
import sys


class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):

class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
    """
    Solver base class. This class is a wrapper of the
    LightningModule class, inheriting all the
@@ -83,7 +84,6 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
                         " optimizers.")

        # extra features handling

        self._pina_models = models
        self._pina_optimizers = optimizers
        self._pina_schedulers = schedulers
@@ -94,7 +94,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
        pass

    @abstractmethod
    def training_step(self, batch, batch_idx):
    def training_step(self, batch):
        pass

    @abstractmethod
@@ -138,8 +138,16 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
        TODO
        """
        for _, condition in problem.conditions.items():
            if not set(self.accepted_condition_types).issubset(
                    condition.condition_type):
            if not set(condition.condition_type).issubset(
                    set(self.accepted_condition_types)):
                raise ValueError(
                    f'{self.__name__} support only dose not support condition '
                    f'{self.__name__} does not support condition '
                    f'{condition.condition_type}')

    @staticmethod
    def get_batch_size(batch):
        # Assuming batch is your custom Batch object
        batch_size = 0
        for data in batch:
            batch_size += len(data[1]['input_points'])
        return batch_size
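Note that this change inverts the compatibility check: the condition's types must now be a subset of what the solver accepts, not vice versa. A worked example with plain strings (the type names here are illustrative):

accepted = ['physics', 'supervised']
assert set(['supervised']).issubset(set(accepted))          # accepted
assert not set(['unsupervised']).issubset(set(accepted))    # rejected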
@@ -1,12 +1,14 @@
""" Module for SupervisedSolver """

import torch
from pytorch_lightning.utilities.types import STEP_OUTPUT
from sympy.strategies.branch import condition
from torch.nn.modules.loss import _Loss
from ..optim import TorchOptimizer, TorchScheduler
from .solver import SolverInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition


class SupervisedSolver(SolverInterface):
@@ -37,7 +39,7 @@ class SupervisedSolver(SolverInterface):
    we are seeking to approximate multiple (discretised) functions given
    multiple (discretised) input functions.
    """
    accepted_condition_types = ['supervised']
    accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
    __name__ = 'SupervisedSolver'

    def __init__(self,
@@ -46,7 +48,8 @@ class SupervisedSolver(SolverInterface):
                 loss=None,
                 optimizer=None,
                 scheduler=None,
                 extra_features=None):
                 extra_features=None,
                 use_lt=True):
        """
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.nn.Module model: The neural network model to use.
@@ -72,14 +75,19 @@ class SupervisedSolver(SolverInterface):
            problem=problem,
            optimizers=optimizer,
            schedulers=scheduler,
            extra_features=extra_features)
            extra_features=extra_features,
            use_lt=use_lt)

        # check consistency
        check_consistency(loss, (LossInterface, _Loss), subclass=False)
        check_consistency(loss, (LossInterface, _Loss, torch.nn.Module),
                          subclass=False)
        self._loss = loss
        self._model = self._pina_models[0]
        self._optimizer = self._pina_optimizers[0]
        self._scheduler = self._pina_schedulers[0]
        self.validation_condition_losses = {
            k: {'loss': [],
                'count': []} for k in self.problem.conditions.keys()}

    def forward(self, x):
        """Forward pass implementation for the solver.

@@ -105,7 +113,7 @@ class SupervisedSolver(SolverInterface):
        return ([self._optimizer.optimizer_instance],
                [self._scheduler.scheduler_instance])

    def training_step(self, batch, batch_idx):
    def training_step(self, batch):
        """Solver training step.

        :param batch: The batch element in the dataloader.
@@ -115,33 +123,37 @@ class SupervisedSolver(SolverInterface):
        :return: The sum of the loss functions.
        :rtype: LabelTensor
        """
        condition_idx = batch.supervised.condition_indices

        for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

            condition_name = self._dataloader.condition_names[condition_id]
            condition = self.problem.conditions[condition_name]
            pts = batch.supervised.input_points
            out = batch.supervised.output_points
            if condition_name not in self.problem.conditions:
                raise RuntimeError("Something wrong happened.")

            # for data driven mode
            if not hasattr(condition, "output_points"):
                raise NotImplementedError(
                    f"{type(self).__name__} works only in data-driven mode.")
            output_pts = out[condition_idx == condition_id]
            input_pts = pts[condition_idx == condition_id]

            input_pts.labels = pts.labels
            output_pts.labels = out.labels

            loss = self.loss_data(input_pts=input_pts, output_pts=output_pts)
            loss = loss.as_subclass(torch.Tensor)

        self.log("mean_loss", float(loss), prog_bar=True, logger=True)
        condition_loss = []
        for condition_name, points in batch:
            input_pts, output_pts = points['input_points'], points['output_points']
            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
            condition_loss.append(loss_.as_subclass(torch.Tensor))
        loss = sum(condition_loss)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=self.get_batch_size(batch),
                 sync_dist=True)
        return loss

    def validation_step(self, batch):
        """
        Solver validation step.
        """
        condition_loss = []
        for condition_name, points in batch:
            input_pts, output_pts = points['input_points'], points['output_points']
            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
            condition_loss.append(loss_.as_subclass(torch.Tensor))
        loss = sum(condition_loss)
        self.log('val_loss', loss, prog_bar=True, logger=True,
                 batch_size=self.get_batch_size(batch), sync_dist=True)

    def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """
        Solver test step.
        """
        raise NotImplementedError("Test step not implemented yet.")

    def loss_data(self, input_pts, output_pts):
        """
        The data loss for the Supervised solver. It computes the loss between

@@ -1,20 +1,21 @@
""" Trainer module. """

import warnings
import torch
import pytorch_lightning
import lightning
from .utils import check_consistency
from .data import PinaDataModule
from .solvers.solver import SolverInterface


class Trainer(pytorch_lightning.Trainer):

class Trainer(lightning.pytorch.Trainer):

    def __init__(self,
                 solver,
                 batch_size=None,
                 train_size=.7,
                 test_size=.2,
                 eval_size=.1,
                 val_size=.1,
                 predict_size=.0,
                 **kwargs):
        """
        PINA Trainer class for customizing every aspect of training via flags.
@@ -39,11 +40,13 @@ class Trainer(pytorch_lightning.Trainer):
            check_consistency(batch_size, int)
        self.train_size = train_size
        self.test_size = test_size
        self.eval_size = eval_size
        self.val_size = val_size
        self.predict_size = predict_size
        self.solver = solver
        self.batch_size = batch_size
        self._create_loader()
        self._move_to_device()
        self.data_module = None
        self._create_loader()

    def _move_to_device(self):
        device = self._accelerator_connector._parallel_devices[0]
@@ -64,34 +67,34 @@ class Trainer(pytorch_lightning.Trainer):
        if not self.solver.problem.collector.full:
            error_message = '\n'.join([
                f"""{" " * 13} ---> Condition {key} {"sampled" if value else
                "not sampled"}""" for key, value in
                self._solver.problem.collector._is_conditions_ready.items()
            ])
            raise RuntimeError('Cannot create Trainer if not all conditions '
                               'are sampled. The Trainer got the following:\n'
                               f'{error_message}')
        devices = self._accelerator_connector._parallel_devices

        if len(devices) > 1:
            raise RuntimeError("Parallel training is not supported yet.")

        device = devices[0]

        data_module = PinaDataModule(problem=self.solver.problem,
                                     device=device,
                                     train_size=self.train_size,
                                     test_size=self.test_size,
                                     val_size=self.eval_size)
        data_module.setup()
        self._loader = data_module.train_dataloader()
        self.data_module = PinaDataModule(collector=self.solver.problem.collector,
                                          train_size=self.train_size,
                                          test_size=self.test_size,
                                          val_size=self.val_size,
                                          predict_size=self.predict_size,
                                          batch_size=self.batch_size)

    def train(self, **kwargs):
        """
        Train the solver method.
        """
        return super().fit(self.solver,
                           train_dataloaders=self._loader,
                           **kwargs)
                           datamodule=self.data_module,
                           **kwargs)

    def test(self, **kwargs):
        """
        Test the solver method.
        """
        return super().test(self.solver,
                            datamodule=self.data_module,
                            **kwargs)

    @property
    def solver(self):

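The refactored train/test methods hand a datamodule to Lightning instead of a prebuilt dataloader. A self-contained sketch of that pattern with generic lightning 2.x code, independent of PINA (the toy data and module are illustrative):

import torch
import lightning

class ToyData(lightning.pytorch.LightningDataModule):
    def setup(self, stage=None):
        self.data = torch.utils.data.TensorDataset(
            torch.rand(32, 2), torch.rand(32, 1))
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.data, batch_size=8)

class ToyModule(lightning.pytorch.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(2, 1)
    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)
    def configure_optimizers(self):
        return torch.optim.Adam(self.net.parameters())

trainer = lightning.pytorch.Trainer(max_epochs=1, accelerator='cpu',
                                    logger=False)
trainer.fit(ToyModule(), datamodule=ToyData())   # fit(solver, datamodule=...)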
@@ -15,7 +15,8 @@ VERSION = meta['__version__']
KEYWORDS = 'machine-learning deep-learning modeling pytorch ode neural-networks differential-equations pde hacktoberfest pinn physics-informed physics-informed-neural-networks neural-operators equation-learning lightining'

REQUIRED = [
    'numpy', 'matplotlib', 'torch', 'lightning', 'pytorch_lightning', 'torch_geometric', 'torch-cluster'
    'numpy', 'matplotlib', 'torch', 'lightning', 'torch_geometric',
    'torch-cluster', 'pytorch_lightning',
]

EXTRAS = {

@@ -1,227 +0,0 @@
import math
import torch
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, \
    UnsupervisedDataset
from pina.data import PinaDataLoader
from pina import LabelTensor, Condition
from pina.equation import Equation
from pina.domain import CartesianDomain
from pina.problem import SpatialProblem, AbstractProblem
from pina.operators import laplacian
from pina.equation.equation_factory import FixedValue
from pina.graph import Graph


def laplace_equation(input_, output_):
    force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
                  torch.sin(input_.extract(['y']) * torch.pi))
    delta_u = laplacian(output_.extract(['u']), input_)
    return delta_u - force_term


my_laplace = Equation(laplace_equation)
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y'])
out2_ = LabelTensor(torch.rand(60, 1), ['u'])


class Poisson(SpatialProblem):
    output_variables = ['u']
    spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})

    conditions = {
        'gamma1':
            Condition(domain=CartesianDomain({
                'x': [0, 1],
                'y': 1
            }),
                      equation=FixedValue(0.0)),
        'gamma2':
            Condition(domain=CartesianDomain({
                'x': [0, 1],
                'y': 0
            }),
                      equation=FixedValue(0.0)),
        'gamma3':
            Condition(domain=CartesianDomain({
                'x': 1,
                'y': [0, 1]
            }),
                      equation=FixedValue(0.0)),
        'gamma4':
            Condition(domain=CartesianDomain({
                'x': 0,
                'y': [0, 1]
            }),
                      equation=FixedValue(0.0)),
        'D':
            Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
                                               ['x', 'y']),
                      equation=my_laplace),
        'data':
            Condition(input_points=in_, output_points=out_),
        'data2':
            Condition(input_points=in2_, output_points=out2_),
        'unsupervised':
            Condition(
                input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
                conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
                                                  ['alpha']),
            ),
        'unsupervised2':
            Condition(
                input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
                conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
                                                  ['alpha']),
            )
    }


boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
poisson = Poisson()
poisson.discretise_domain(10, 'grid', locations=boundaries)


def test_sample():
    sample_dataset = SamplePointDataset(poisson, device='cpu')
    assert len(sample_dataset) == 140
    assert sample_dataset.input_points.shape == (140, 2)
    assert sample_dataset.input_points.labels == ['x', 'y']
    assert sample_dataset.condition_indices.dtype == torch.uint8
    assert sample_dataset.condition_indices.max() == torch.tensor(4)
    assert sample_dataset.condition_indices.min() == torch.tensor(0)


def test_data():
    dataset = SupervisedDataset(poisson, device='cpu')
    assert len(dataset) == 61
    assert dataset['input_points'].shape == (61, 2)
    assert dataset.input_points.shape == (61, 2)
    assert dataset['input_points'].labels == ['x', 'y']
    assert dataset.input_points.labels == ['x', 'y']
    assert dataset.input_points[3:].shape == (58, 2)
    assert dataset.output_points[:3].labels == ['u']
    assert dataset.output_points.shape == (61, 1)
    assert dataset.output_points.labels == ['u']
    assert dataset.condition_indices.dtype == torch.uint8
    assert dataset.condition_indices.max() == torch.tensor(1)
    assert dataset.condition_indices.min() == torch.tensor(0)


def test_unsupervised():
    dataset = UnsupervisedDataset(poisson, device='cpu')
    assert len(dataset) == 135
    assert dataset.input_points.shape == (135, 2)
    assert dataset.input_points.labels == ['x', 'y']
    assert dataset.input_points[3:].shape == (132, 2)

    assert dataset.conditional_variables.shape == (135, 1)
    assert dataset.conditional_variables.labels == ['alpha']
    assert dataset.condition_indices.dtype == torch.uint8
    assert dataset.condition_indices.max() == torch.tensor(1)
    assert dataset.condition_indices.min() == torch.tensor(0)


def test_data_module():
    data_module = PinaDataModule(poisson, device='cpu')
    data_module.setup()
    loader = data_module.train_dataloader()
    assert isinstance(loader, PinaDataLoader)

    data_module = PinaDataModule(poisson,
                                 device='cpu',
                                 batch_size=10,
                                 shuffle=False)
    data_module.setup()
    loader = data_module.train_dataloader()
    assert len(loader) == 24
    for i in loader:
        assert len(i) <= 10
    len_ref = sum(
        [math.ceil(len(dataset) * 0.7) for dataset in data_module.datasets])
    len_real = sum(
        [len(dataset) for dataset in data_module.splits['train'].values()])
    assert len_ref == len_real

    supervised_dataset = SupervisedDataset(poisson, device='cpu')
    data_module = PinaDataModule(poisson,
                                 device='cpu',
                                 batch_size=10,
                                 shuffle=False,
                                 datasets=[supervised_dataset])
    data_module.setup()
    loader = data_module.train_dataloader()
    for batch in loader:
        assert len(batch) <= 10

    physics_dataset = SamplePointDataset(poisson, device='cpu')
    data_module = PinaDataModule(poisson,
                                 device='cpu',
                                 batch_size=10,
                                 shuffle=False,
                                 datasets=[physics_dataset])
    data_module.setup()
    loader = data_module.train_dataloader()
    for batch in loader:
        assert len(batch) <= 10

    unsupervised_dataset = UnsupervisedDataset(poisson, device='cpu')
    data_module = PinaDataModule(poisson,
                                 device='cpu',
                                 batch_size=10,
                                 shuffle=False,
                                 datasets=[unsupervised_dataset])
    data_module.setup()
    loader = data_module.train_dataloader()
    for batch in loader:
        assert len(batch) <= 10


def test_loader():
    data_module = PinaDataModule(poisson, device='cpu', batch_size=10)
    data_module.setup()
    loader = data_module.train_dataloader()
    assert isinstance(loader, PinaDataLoader)
    assert len(loader) == 24
    for i in loader:
        assert len(i) <= 10
        assert i.supervised.input_points.labels == ['x', 'y']
        assert i.physics.input_points.labels == ['x', 'y']
        assert i.unsupervised.input_points.labels == ['x', 'y']
        assert i.supervised.input_points.requires_grad == True
        assert i.physics.input_points.requires_grad == True
        assert i.unsupervised.input_points.requires_grad == True


coordinates = LabelTensor(torch.rand((100, 100, 2)), labels=['x', 'y'])
data = LabelTensor(torch.rand((100, 100, 3)), labels=['ux', 'uy', 'p'])


class GraphProblem(AbstractProblem):
    output = LabelTensor(torch.rand((100, 3)), labels=['ux', 'uy', 'p'])
    input = [
        Graph.build('radius',
                    nodes_coordinates=coordinates[i, :, :],
                    nodes_data=data[i, :, :],
                    radius=0.2) for i in range(100)
    ]
    output_variables = ['u']

    conditions = {
        'graph_data': Condition(input_points=input, output_points=output)
    }


graph_problem = GraphProblem()


def test_loader_graph():
    data_module = PinaDataModule(graph_problem, device='cpu', batch_size=10)
    data_module.setup()
    loader = data_module.train_dataloader()
    for i in loader:
        assert len(i) <= 10
        assert isinstance(i.supervised.input_points, list)
        assert all(isinstance(x, Graph) for x in i.supervised.input_points)
@@ -114,5 +114,5 @@ def test_slice():
    assert torch.allclose(tensor_view2, data[3])

    tensor_view3 = tensor[:, 2]
    assert tensor_view3.labels == labels[2]
    assert tensor_view3.labels == [labels[2]]
    assert torch.allclose(tensor_view3, data[:, 2].reshape(-1, 1))

@@ -1,5 +1,4 @@
import torch
from pina.problem import SpatialProblem, InverseProblem
from pina.operators import laplacian
from pina.domain import CartesianDomain
@@ -9,7 +8,7 @@ from pina.trainer import Trainer
from pina.model import FeedForward
from pina.equation.equation import Equation
from pina.equation.equation_factory import FixedValue
from pina.loss.loss_interface import LpLoss
from pina.loss import LpLoss


def laplace_equation(input_, output_):
@@ -54,22 +53,22 @@ class InversePoisson(SpatialProblem, InverseProblem):

    # define the conditions for the loss (boundary conditions, equation, data)
    conditions = {
        'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max],
        'gamma1': Condition(domain=CartesianDomain({'x': [x_min, x_max],
                                                    'y': y_max}),
                            equation=FixedValue(0.0, components=['u'])),
        'gamma2': Condition(location=CartesianDomain(
        'gamma2': Condition(domain=CartesianDomain(
            {'x': [x_min, x_max], 'y': y_min
             }),
            equation=FixedValue(0.0, components=['u'])),
        'gamma3': Condition(location=CartesianDomain(
        'gamma3': Condition(domain=CartesianDomain(
            {'x': x_max, 'y': [y_min, y_max]
             }),
            equation=FixedValue(0.0, components=['u'])),
        'gamma4': Condition(location=CartesianDomain(
        'gamma4': Condition(domain=CartesianDomain(
            {'x': x_min, 'y': [y_min, y_max]
             }),
            equation=FixedValue(0.0, components=['u'])),
        'D': Condition(location=CartesianDomain(
        'D': Condition(domain=CartesianDomain(
            {'x': [x_min, x_max], 'y': [y_min, y_max]
             }),
            equation=Equation(laplace_equation)),
@@ -84,16 +83,16 @@ class Poisson(SpatialProblem):

    conditions = {
        'gamma1': Condition(
            location=CartesianDomain({'x': [0, 1], 'y': 1}),
            domain=CartesianDomain({'x': [0, 1], 'y': 1}),
            equation=FixedValue(0.0)),
        'gamma2': Condition(
            location=CartesianDomain({'x': [0, 1], 'y': 0}),
            domain=CartesianDomain({'x': [0, 1], 'y': 0}),
            equation=FixedValue(0.0)),
        'gamma3': Condition(
            location=CartesianDomain({'x': 1, 'y': [0, 1]}),
            domain=CartesianDomain({'x': 1, 'y': [0, 1]}),
            equation=FixedValue(0.0)),
        'gamma4': Condition(
            location=CartesianDomain({'x': 0, 'y': [0, 1]}),
            domain=CartesianDomain({'x': 0, 'y': [0, 1]}),
            equation=FixedValue(0.0)),
        'D': Condition(
            input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']),
@@ -112,7 +111,6 @@ class Poisson(SpatialProblem):

    truth_solution = poisson_sol


class myFeature(torch.nn.Module):
    """
    Feature: sin(x)
@@ -158,21 +156,35 @@ def test_train_cpu():
    pinn = PINN(problem = poisson_problem, model=model,
                extra_features=None, loss=LpLoss())
    trainer = Trainer(solver=pinn, max_epochs=1,
                      accelerator='cpu', batch_size=20)
    trainer.train()
                      accelerator='cpu', batch_size=20, val_size=0., train_size=1., test_size=0.)

def test_log():
    poisson_problem.discretise_domain(100)
    solver = PINN(problem = poisson_problem, model=model,
                  extra_features=None, loss=LpLoss())
    trainer = Trainer(solver, max_epochs=2, accelerator='cpu')
def test_train_load():
    tmpdir = "tests/tmp_load"
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem,
                model=model,
                extra_features=None,
                loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=15,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    # assert the logged metrics are correct
    logged_metrics = sorted(list(trainer.logged_metrics.keys()))
    total_metrics = sorted(
        list([key + '_loss' for key in poisson_problem.conditions.keys()])
        + ['mean_loss'])
    assert logged_metrics == total_metrics
    new_pinn = PINN.load_from_checkpoint(
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
        problem = poisson_problem, model=model)
    test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
    assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
    assert new_pinn.forward(test_pts).extract(
        ['u']).shape == pinn.forward(test_pts).extract(['u']).shape
    torch.testing.assert_close(
        new_pinn.forward(test_pts).extract(['u']),
        pinn.forward(test_pts).extract(['u']))
    import shutil
    shutil.rmtree(tmpdir)
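The checkpoint names in these tests encode Lightning's global step count (epochs times optimizer steps per epoch); the change from step=30 to step=15 in this diff suggests one step per epoch under the new datamodule splits. A worked check of the naming arithmetic (the split interpretation is an assumption):

epochs, steps_per_epoch = 15, 1
last = f"epoch={epochs - 1}-step={epochs * steps_per_epoch}"
assert last == "epoch=14-step=15"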
def test_train_restore():
    tmpdir = "tests/tmp_restore"
@@ -192,36 +204,7 @@ def test_train_restore():
    ntrainer = Trainer(solver=pinn, max_epochs=15, accelerator='cpu')
    t = ntrainer.train(
        ckpt_path=f'{tmpdir}/lightning_logs/version_0/'
        'checkpoints/epoch=4-step=10.ckpt')
    import shutil
    shutil.rmtree(tmpdir)


def test_train_load():
    tmpdir = "tests/tmp_load"
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem,
                model=model,
                extra_features=None,
                loss=LpLoss())
    trainer = Trainer(solver=pinn,
                      max_epochs=15,
                      accelerator='cpu',
                      default_root_dir=tmpdir)
    trainer.train()
    new_pinn = PINN.load_from_checkpoint(
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
        problem = poisson_problem, model=model)
    test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
    assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
    assert new_pinn.forward(test_pts).extract(
        ['u']).shape == pinn.forward(test_pts).extract(['u']).shape
    torch.testing.assert_close(
        new_pinn.forward(test_pts).extract(['u']),
        pinn.forward(test_pts).extract(['u']))
        'checkpoints/epoch=4-step=5.ckpt')
    import shutil
    shutil.rmtree(tmpdir)

@@ -229,36 +212,24 @@ def test_train_inverse_problem_cpu():
    poisson_problem = InversePoisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
    n = 100
    poisson_problem.discretise_domain(n, 'random', locations=boundaries)
    poisson_problem.discretise_domain(n, 'random', locations=boundaries,
                                      variables=['x', 'y'])
    pinn = PINN(problem = poisson_problem, model=model,
                extra_features=None, loss=LpLoss())
    trainer = Trainer(solver=pinn, max_epochs=1,
                      accelerator='cpu', batch_size=20)
    trainer.train()


# # TODO does not currently work
# def test_train_inverse_problem_restore():
#     tmpdir = "tests/tmp_restore_inv"
#     poisson_problem = InversePoisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4', 'D']
#     n = 100
#     poisson_problem.discretise_domain(n, 'random', locations=boundaries)
#     pinn = PINN(problem=poisson_problem,
#                 model=model,
#                 extra_features=None,
#                 loss=LpLoss())
#     trainer = Trainer(solver=pinn,
#                       max_epochs=5,
#                       accelerator='cpu',
#                       default_root_dir=tmpdir)
#     trainer.train()
#     ntrainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
#     t = ntrainer.train(
#         ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=10.ckpt')
#     import shutil
#     shutil.rmtree(tmpdir)

def test_train_extra_feats_cpu():
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem,
                model=model_extra_feats,
                extra_features=extra_feats)
    trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
    trainer.train()

def test_train_inverse_problem_load():
    tmpdir = "tests/tmp_load_inv"
@@ -276,7 +247,7 @@ def test_train_inverse_problem_load():
                      default_root_dir=tmpdir)
    trainer.train()
    new_pinn = PINN.load_from_checkpoint(
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=30.ckpt',
        f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
        problem = poisson_problem, model=model)
    test_pts = CartesianDomain({'x': [0, 1], 'y': [0, 1]}).sample(10)
    assert new_pinn.forward(test_pts).extract(['u']).shape == (10, 1)
@@ -287,159 +258,3 @@ def test_train_inverse_problem_load():
        pinn.forward(test_pts).extract(['u']))
    import shutil
    shutil.rmtree(tmpdir)

# # TODO fix asap. Basically sampling few variables
# # works only if both variables are in a range.
# # if one is fixed and the other not, this will
# # not work. This test also needs to be fixed and
# # insert in test problem not in test pinn.
# def test_train_cpu_sampling_few_vars():
#     poisson_problem = Poisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3']
#     n = 10
#     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
#     poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['x'])
#     poisson_problem.discretise_domain(n, 'random', locations=['gamma4'], variables=['y'])
#     pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
#     trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'cpu'})
#     trainer.train()


def test_train_extra_feats_cpu():
    poisson_problem = Poisson()
    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
    n = 10
    poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
    pinn = PINN(problem=poisson_problem,
                model=model_extra_feats,
                extra_features=extra_feats)
    trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
    trainer.train()


# TODO, fix GitHub actions to run also on GPU
# def test_train_gpu():
#     poisson_problem = Poisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
#     pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
#     trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
#     trainer.train()

# def test_train_gpu(): #TODO fix ASAP
#     poisson_problem = Poisson()
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
#     poisson_problem.conditions.pop('data') # The input/output pts are allocated on cpu
#     pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
#     trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5, 'accelerator':'gpu'})
#     trainer.train()

# def test_train_2():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = PINN(problem, model)
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# def test_train_extra_feats():
#     pinn = PINN(problem, model_extra_feat, [myFeature()])
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     pinn.discretise_domain(n, 'grid', locations=boundaries)
#     pinn.discretise_domain(n, 'grid', locations=['D'])
#     pinn.train(5)


# def test_train_2_extra_feats():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = PINN(problem, model_extra_feat, [myFeature()])
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# def test_train_with_optimizer_kwargs():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3})
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# def test_train_with_lr_scheduler():
#     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     n = 10
#     expected_keys = [[], list(range(0, 50, 3))]
#     param = [0, 3]
#     for i, truth_key in zip(param, expected_keys):
#         pinn = PINN(
#             problem,
#             model,
#             lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
#             lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False}
#         )
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(50, save_loss=i)
#         assert list(pinn.history_loss.keys()) == truth_key


# # def test_train_batch():
# #     pinn = PINN(problem, model, batch_size=6)
# #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
# #     n = 10
# #     pinn.discretise_domain(n, 'grid', locations=boundaries)
# #     pinn.discretise_domain(n, 'grid', locations=['D'])
# #     pinn.train(5)


# # def test_train_batch_2():
# #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
# #     n = 10
# #     expected_keys = [[], list(range(0, 50, 3))]
# #     param = [0, 3]
# #     for i, truth_key in zip(param, expected_keys):
# #         pinn = PINN(problem, model, batch_size=6)
# #         pinn.discretise_domain(n, 'grid', locations=boundaries)
# #         pinn.discretise_domain(n, 'grid', locations=['D'])
# #         pinn.train(50, save_loss=i)
# #         assert list(pinn.history_loss.keys()) == truth_key


# if torch.cuda.is_available():

#     # def test_gpu_train():
#     #     pinn = PINN(problem, model, batch_size=20, device='cuda')
#     #     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#     #     n = 100
#     #     pinn.discretise_domain(n, 'grid', locations=boundaries)
#     #     pinn.discretise_domain(n, 'grid', locations=['D'])
#     #     pinn.train(5)

#     def test_gpu_train_nobatch():
#         pinn = PINN(problem, model, batch_size=None, device='cuda')
#         boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
#         n = 100
#         pinn.discretise_domain(n, 'grid', locations=boundaries)
#         pinn.discretise_domain(n, 'grid', locations=['D'])
#         pinn.train(5)


@@ -121,7 +121,7 @@ def test_train_cpu():
                      batch_size=5,
                      train_size=1,
                      test_size=0.,
                      eval_size=0.)
                      val_size=0.)
    trainer.train()
test_train_cpu()