Lightining update (#104)
* multiple functions for version 0.0 * lightining update * minor changes * data pinn loss added --------- Co-authored-by: Nicola Demo <demo.nicola@gmail.com> Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-3-125.WIFIeduroamSTUD.units.it> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.station> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local> Co-authored-by: Dario Coscia <dariocoscia@192.168.1.38>
This commit is contained in:
committed by
Nicola Demo
parent
0e3625de80
commit
63fd068988
@@ -1,5 +1,6 @@
|
||||
__all__ = [
|
||||
'PINN',
|
||||
'Trainer',
|
||||
'LabelTensor',
|
||||
'Plotter',
|
||||
'Condition',
|
||||
@@ -10,6 +11,7 @@ __all__ = [
|
||||
from .meta import *
|
||||
from .label_tensor import LabelTensor
|
||||
from .pinn import PINN
|
||||
from .trainer import Trainer
|
||||
from .plotter import Plotter
|
||||
from .condition import Condition
|
||||
from .geometry import Location
|
||||
|
||||
@@ -117,6 +117,7 @@ class LabelTensorDataset(Dataset):
|
||||
def __len__(self):
|
||||
return max([len(getattr(self, label)) for label in self.labels])
|
||||
|
||||
# TODO: working also for datapoints
|
||||
class DummyLoader:
|
||||
|
||||
def __init__(self, data) -> None:
|
||||
|
||||
@@ -88,6 +88,8 @@ class LabelTensor(torch.Tensor):
|
||||
|
||||
self._labels = labels # assign the label
|
||||
|
||||
# TODO remove try/ except thing IMPORTANT
|
||||
# make the label None of default
|
||||
def clone(self, *args, **kwargs):
|
||||
"""
|
||||
Clone the LabelTensor. For more details, see
|
||||
@@ -96,7 +98,12 @@ class LabelTensor(torch.Tensor):
|
||||
:return: a copy of the tensor
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||
try:
|
||||
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
|
||||
except:
|
||||
out = super().clone(*args, **kwargs)
|
||||
|
||||
return out
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""
|
||||
|
||||
127
pina/loss.py
Normal file
127
pina/loss.py
Normal file
@@ -0,0 +1,127 @@
|
||||
""" Module for EquationInterface class """
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from torch.nn.modules.loss import _Loss
|
||||
import torch
|
||||
from .utils import check_consistency
|
||||
|
||||
__all__ = ['LpLoss']
|
||||
|
||||
class LossInterface(_Loss, metaclass=ABCMeta):
|
||||
"""
|
||||
The abstract `LossInterface` class. All the class defining a PINA Loss
|
||||
should be inheritied from this class.
|
||||
"""
|
||||
|
||||
def __init__(self, reduction = 'mean'):
|
||||
"""
|
||||
:param str reduction: Specifies the reduction to apply to the output:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
|
||||
will be applied, ``'mean'``: the sum of the output will be divided
|
||||
by the number of elements in the output, ``'sum'``: the output will
|
||||
be summed. Note: :attr:`size_average` and :attr:`reduce` are in the
|
||||
process of being deprecated, and in the meantime, specifying either of
|
||||
those two args will override :attr:`reduction`. Default: ``'mean'``.
|
||||
"""
|
||||
super().__init__(reduction=reduction, size_average=None, reduce=None)
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
def _reduction(self, loss):
|
||||
"""Simple helper function to check reduction
|
||||
|
||||
:param reduction: Specifies the reduction to apply to the output:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
|
||||
will be applied, ``'mean'``: the sum of the output will be divided
|
||||
by the number of elements in the output, ``'sum'``: the output will
|
||||
be summed. Note: :attr:`size_average` and :attr:`reduce` are in the
|
||||
process of being deprecated, and in the meantime, specifying either of
|
||||
those two args will override :attr:`reduction`. Default: ``'mean'``.
|
||||
:type reduction: str, optional
|
||||
:param loss: Loss tensor for each element.
|
||||
:type loss: torch.Tensor
|
||||
:return: Reduced loss.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if self.reduction == "none":
|
||||
ret = loss
|
||||
elif self.reduction == "mean":
|
||||
ret = torch.mean(loss, keepdim=True, dim=-1)
|
||||
elif self.reduction == "sum":
|
||||
ret = torch.sum(loss, keepdim=True, dim=-1)
|
||||
else:
|
||||
raise ValueError(self.reduction + " is not valid")
|
||||
return ret
|
||||
|
||||
class LpLoss(LossInterface):
|
||||
"""
|
||||
The Lp loss implementation class. Creates a criterion that measures
|
||||
the Lp error between each element in the input :math:`x` and
|
||||
target :math:`y`.
|
||||
|
||||
The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can
|
||||
be described as:
|
||||
|
||||
.. math::
|
||||
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
|
||||
l_n = \left| x_n - y_n \right|^p,
|
||||
|
||||
If ``'relative'`` is set to true:
|
||||
|
||||
.. math::
|
||||
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
|
||||
l_n = \left[\frac{\left| x_n - y_n \right|^p}{\left|y_n \right|^p}\right]^{1/p},
|
||||
|
||||
where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
|
||||
(default ``'mean'``), then:
|
||||
|
||||
.. math::
|
||||
\ell(x, y) =
|
||||
\begin{cases}
|
||||
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
|
||||
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
|
||||
\end{cases}
|
||||
|
||||
:math:`x` and :math:`y` are tensors of arbitrary shapes with a total
|
||||
of :math:`n` elements each.
|
||||
|
||||
The sum operation still operates over all the elements, and divides by :math:`n`.
|
||||
|
||||
The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
|
||||
"""
|
||||
|
||||
def __init__(self, p=2, reduction = 'mean', relative = False):
|
||||
"""
|
||||
:param int p: Degree of Lp norm. It specifies the type of norm to
|
||||
be calculated. See :meth:`torch.linalg.norm` ```'ord'``` to
|
||||
see the possible degrees. Default 2 (euclidean norm).
|
||||
:param str reduction: Specifies the reduction to apply to the output:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
|
||||
will be applied, ``'mean'``: the sum of the output will be divided
|
||||
by the number of elements in the output, ``'sum'``: the output will
|
||||
be summed. Note: :attr:`size_average` and :attr:`reduce` are in the
|
||||
process of being deprecated, and in the meantime, specifying either of
|
||||
those two args will override :attr:`reduction`. Default: ``'mean'``.
|
||||
:param bool relative: Specifies if relative error should be computed.
|
||||
"""
|
||||
super().__init__(reduction=reduction)
|
||||
|
||||
# check consistency
|
||||
check_consistency(p, (str,int,float), 'degree p')
|
||||
self.p = p
|
||||
check_consistency(relative, bool, 'relative')
|
||||
self.relative = relative
|
||||
|
||||
def forward(self, input, target):
|
||||
"""Forward method for loss function.
|
||||
|
||||
:param torch.Tensor input: Input tensor from real data.
|
||||
:param torch.Tensor target: Model tensor output.
|
||||
:return: Loss evaluation.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
loss = torch.linalg.norm((input-target), ord=self.p, dim=-1)
|
||||
if self.relative:
|
||||
loss = loss / torch.linalg.norm(input, ord=self.p, dim=-1)
|
||||
return self._reduction(loss)
|
||||
@@ -2,10 +2,8 @@ __all__ = [
|
||||
'FeedForward',
|
||||
'MultiFeedForward',
|
||||
'DeepONet',
|
||||
'Network',
|
||||
]
|
||||
|
||||
from .feed_forward import FeedForward
|
||||
from .multi_feed_forward import MultiFeedForward
|
||||
from .deeponet import DeepONet
|
||||
from .network import Network
|
||||
|
||||
@@ -1,107 +1,47 @@
|
||||
import torch
|
||||
from pina.label_tensor import LabelTensor
|
||||
import torch.nn as nn
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
class Network(torch.nn.Module):
|
||||
"""The PINA implementation of any neural network.
|
||||
|
||||
:param torch.nn.Module model: the torch model of the network.
|
||||
:param list(str) input_variables: the list containing the labels
|
||||
corresponding to the input components of the model.
|
||||
:param list(str) output_variables: the list containing the labels
|
||||
corresponding to the components of the output computed by the model.
|
||||
:param torch.nn.Module extra_features: the additional input
|
||||
features to use as augmented input.
|
||||
|
||||
:Example:
|
||||
>>> class SimpleNet(nn.Module):
|
||||
>>> def __init__(self):
|
||||
>>> super().__init__()
|
||||
>>> self.layers = nn.Sequential(
|
||||
>>> nn.Linear(2, 20),
|
||||
>>> nn.Tanh(),
|
||||
>>> nn.Linear(20, 1)
|
||||
>>> )
|
||||
>>> def forward(self, x):
|
||||
>>> return self.layers(x)
|
||||
>>> net = SimpleNet()
|
||||
>>> input_variables = ['x', 'y']
|
||||
>>> output_variables =['u']
|
||||
>>> model_feat = Network(net, input_variables, output_variables)
|
||||
Network(
|
||||
(extra_features): Sequential()
|
||||
(model): Sequential(
|
||||
(0): Linear(in_features=2, out_features=20, bias=True)
|
||||
(1): Tanh()
|
||||
(2): Linear(in_features=20, out_features=1, bias=True)
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, model, input_variables,
|
||||
output_variables, extra_features=None):
|
||||
def __init__(self, model, extra_features=None):
|
||||
super().__init__()
|
||||
|
||||
print('HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH')
|
||||
if extra_features is None:
|
||||
extra_features = []
|
||||
|
||||
self._extra_features = torch.nn.Sequential(*extra_features)
|
||||
# check model consistency
|
||||
check_consistency(model, nn.Module, 'torch model')
|
||||
self._model = model
|
||||
self._input_variables = input_variables
|
||||
self._output_variables = output_variables
|
||||
print(output_variables)
|
||||
|
||||
# check model and input/output
|
||||
self._check_consistency()
|
||||
# check consistency and assign extra fatures
|
||||
if extra_features is None:
|
||||
self._extra_features = []
|
||||
else:
|
||||
for feat in extra_features:
|
||||
check_consistency(feat, nn.Module, 'extra features')
|
||||
self._extra_features = nn.Sequential(*extra_features)
|
||||
|
||||
def _check_consistency(self):
|
||||
"""Checking the consistency of model with input and output variables
|
||||
|
||||
:raises ValueError: Error in constructing the PINA network
|
||||
"""
|
||||
try:
|
||||
pass
|
||||
# tmp = torch.rand((10, len(self._input_variables)))
|
||||
# tmp = LabelTensor(tmp, self._input_variables)
|
||||
# tmp = self.forward(tmp) # trying a forward pass
|
||||
# tmp = LabelTensor(tmp, self._output_variables)
|
||||
except:
|
||||
raise ValueError('Error in constructing the PINA network.'
|
||||
' Check compatibility of input/output'
|
||||
' variables shape with the torch model'
|
||||
' or check the correctness of the torch'
|
||||
' model itself.')
|
||||
# check model works with inputs
|
||||
# TODO
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward method for Network class
|
||||
"""
|
||||
Forward method for Network class. This class
|
||||
implements the standard forward method, and
|
||||
it adds the possibility to pass extra features.
|
||||
|
||||
:param torch.tensor x: input of the network
|
||||
:return torch.tensor: output of the network
|
||||
"""
|
||||
|
||||
x = x.extract(self._input_variables)
|
||||
|
||||
# extract features and append
|
||||
for feature in self._extra_features:
|
||||
x = x.append(feature(x))
|
||||
|
||||
output = self._model(x).as_subclass(LabelTensor)
|
||||
output.labels = self._output_variables
|
||||
|
||||
return output
|
||||
|
||||
@property
|
||||
def input_variables(self):
|
||||
return self._input_variables
|
||||
|
||||
@property
|
||||
def output_variables(self):
|
||||
return self._output_variables
|
||||
|
||||
@property
|
||||
def extra_features(self):
|
||||
return self._extra_features
|
||||
# perform forward pass
|
||||
return self._model(x)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def extra_features(self):
|
||||
return self._extra_features
|
||||
376
pina/pinn.py
376
pina/pinn.py
@@ -2,298 +2,96 @@
|
||||
import torch
|
||||
import torch.optim.lr_scheduler as lrs
|
||||
|
||||
from .problem import AbstractProblem
|
||||
from .model import Network
|
||||
|
||||
from .solver import SolverInterface
|
||||
from .label_tensor import LabelTensor
|
||||
from .utils import merge_tensors
|
||||
from .dataset import DummyLoader
|
||||
from .utils import check_consistency
|
||||
from .writer import Writer
|
||||
from .loss import LossInterface
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
|
||||
|
||||
class PINN(object):
|
||||
class PINN(SolverInterface):
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
extra_features=None,
|
||||
loss = torch.nn.MSELoss(),
|
||||
optimizer=torch.optim.Adam,
|
||||
optimizer_kwargs=None,
|
||||
lr=0.001,
|
||||
lr_scheduler_type=lrs.ConstantLR,
|
||||
lr_scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||
regularizer=0.00001,
|
||||
batch_size=None,
|
||||
dtype=torch.float32,
|
||||
device='cpu',
|
||||
writer=None,
|
||||
error_norm='mse'):
|
||||
optimizer_kwargs={'lr' : 0.001},
|
||||
scheduler=lrs.ConstantLR,
|
||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
||||
):
|
||||
'''
|
||||
:param AbstractProblem problem: the formualation of the problem.
|
||||
:param torch.nn.Module model: the neural network model to use.
|
||||
:param torch.nn.Module extra_features: the additional input
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default torch.nn.MSELoss().
|
||||
:param torch.nn.Module extra_features: The additional input
|
||||
features to use as augmented input.
|
||||
:param torch.optim.Optimizer optimizer: the neural network optimizer to
|
||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
||||
use; default is `torch.optim.Adam`.
|
||||
:param dict optimizer_kwargs: Optimizer constructor keyword args.
|
||||
:param float lr: the learning rate; default is 0.001.
|
||||
:param torch.optim.LRScheduler lr_scheduler_type: Learning
|
||||
:param float lr: The learning rate; default is 0.001.
|
||||
:param torch.optim.LRScheduler scheduler: Learning
|
||||
rate scheduler.
|
||||
:param dict lr_scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
:param float regularizer: the coefficient for L2 regularizer term.
|
||||
:param type dtype: the data type to use for the model. Valid option are
|
||||
`torch.float32` and `torch.float64` (`torch.float16` only on GPU);
|
||||
default is `torch.float64`.
|
||||
:param str device: the device used for training; default 'cpu'
|
||||
option include 'cuda' if cuda is available.
|
||||
:param (str, int) error_norm: the loss function used as minimizer,
|
||||
default mean square error 'mse'. If string options include mean
|
||||
error 'me' and mean square error 'mse'. If int, the p-norm is
|
||||
calculated where p is specifined by the int input.
|
||||
:param int batch_size: batch size for the dataloader; default 5.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
'''
|
||||
super().__init__(model=model, problem=problem, extra_features=extra_features)
|
||||
|
||||
if dtype == torch.float64:
|
||||
raise NotImplementedError('only float for now')
|
||||
# check consistency
|
||||
check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True)
|
||||
check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs')
|
||||
check_consistency(scheduler, lrs.LRScheduler, 'scheduler', subclass=True)
|
||||
check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs')
|
||||
check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False)
|
||||
|
||||
self.problem = problem
|
||||
|
||||
# self._architecture = architecture if architecture else dict()
|
||||
# self._architecture['input_dimension'] = self.problem.domain_bound.shape[0]
|
||||
# self._architecture['output_dimension'] = len(self.problem.variables)
|
||||
# if hasattr(self.problem, 'params_domain'):
|
||||
# self._architecture['input_dimension'] += self.problem.params_domain.shape[0]
|
||||
|
||||
self.error_norm = error_norm
|
||||
|
||||
if device == 'cuda' and not torch.cuda.is_available():
|
||||
raise RuntimeError
|
||||
self.device = torch.device(device)
|
||||
|
||||
self.dtype = dtype
|
||||
self.history_loss = {}
|
||||
# assign variables
|
||||
self._optimizer = optimizer(self.model.parameters(), **optimizer_kwargs)
|
||||
self._scheduler = scheduler(self._optimizer, **scheduler_kwargs)
|
||||
self._loss = loss
|
||||
self._writer = Writer()
|
||||
|
||||
|
||||
self.model = Network(model=model,
|
||||
input_variables=problem.input_variables,
|
||||
output_variables=problem.output_variables,
|
||||
extra_features=extra_features)
|
||||
def forward(self, x):
|
||||
"""Forward pass implementation for the PINN
|
||||
solver.
|
||||
|
||||
self.model.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
self.truth_values = {}
|
||||
self.input_pts = {}
|
||||
|
||||
self.trained_epoch = 0
|
||||
|
||||
from .writer import Writer
|
||||
if writer is None:
|
||||
writer = Writer()
|
||||
self.writer = writer
|
||||
|
||||
if not optimizer_kwargs:
|
||||
optimizer_kwargs = {}
|
||||
optimizer_kwargs['lr'] = lr
|
||||
self.optimizer = optimizer(
|
||||
self.model.parameters())#, weight_decay=regularizer, **optimizer_kwargs)
|
||||
#self._lr_scheduler = lr_scheduler_type(
|
||||
# self.optimizer, **lr_scheduler_kwargs)
|
||||
|
||||
self.batch_size = batch_size
|
||||
# self.data_set = PinaDataset(self)
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
""" The problem formulation."""
|
||||
return self._problem
|
||||
|
||||
@problem.setter
|
||||
def problem(self, problem):
|
||||
:param torch.tensor x: Input data.
|
||||
:return: PINN solution.
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
Set the problem formulation."""
|
||||
if not isinstance(problem, AbstractProblem):
|
||||
raise TypeError
|
||||
self._problem = problem
|
||||
# extract labels
|
||||
x = x.extract(self.problem.input_variables)
|
||||
# perform forward pass
|
||||
output = self.model(x).as_subclass(LabelTensor)
|
||||
# set the labels
|
||||
output.labels = self.problem.output_variables
|
||||
return output
|
||||
|
||||
def _compute_norm(self, vec):
|
||||
def configure_optimizers(self):
|
||||
"""Optimizer configuration for the PINN
|
||||
solver.
|
||||
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
Compute the norm of the `vec` one-dimensional tensor based on the
|
||||
`self.error_norm` attribute.
|
||||
return [self._optimizer], [self._scheduler]
|
||||
|
||||
.. todo: complete
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""PINN solver training step.
|
||||
|
||||
:param torch.Tensor vec: the tensor
|
||||
:param batch: The batch element in the dataloader.
|
||||
:type batch: tuple
|
||||
:param batch_idx: The batch index.
|
||||
:type batch_idx: int
|
||||
:return: The sum of the loss functions.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
if isinstance(self.error_norm, int):
|
||||
return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dytpe)
|
||||
elif self.error_norm == 'mse':
|
||||
return torch.mean(vec.pow(2))
|
||||
elif self.error_norm == 'me':
|
||||
return torch.mean(torch.abs(vec))
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
def save_state(self, filename):
|
||||
"""
|
||||
Save the state of the model.
|
||||
|
||||
:param str filename: the filename to save the state to.
|
||||
"""
|
||||
checkpoint = {
|
||||
'epoch': self.trained_epoch,
|
||||
'model_state': self.model.state_dict(),
|
||||
'optimizer_state': self.optimizer.state_dict(),
|
||||
'optimizer_class': self.optimizer.__class__,
|
||||
'history': self.history_loss,
|
||||
'input_points_dict': self.input_pts,
|
||||
}
|
||||
|
||||
# TODO save also architecture param?
|
||||
# if isinstance(self.model, DeepFeedForward):
|
||||
# checkpoint['model_class'] = self.model.__class__
|
||||
# checkpoint['model_structure'] = {
|
||||
# }
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
def load_state(self, filename):
|
||||
"""
|
||||
Load the state of the model.
|
||||
|
||||
:param str filename: the filename to load the state from.
|
||||
"""
|
||||
|
||||
checkpoint = torch.load(filename)
|
||||
self.model.load_state_dict(checkpoint['model_state'])
|
||||
|
||||
self.optimizer = checkpoint['optimizer_class'](self.model.parameters())
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state'])
|
||||
|
||||
self.trained_epoch = checkpoint['epoch']
|
||||
self.history_loss = checkpoint['history']
|
||||
|
||||
self.input_pts = checkpoint['input_points_dict']
|
||||
|
||||
return self
|
||||
|
||||
def span_pts(self, *args, **kwargs):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
the problem.
|
||||
|
||||
>>> pinn.span_pts(n=10, mode='grid')
|
||||
>>> pinn.span_pts(n=10, mode='grid', location=['bound1'])
|
||||
>>> pinn.span_pts(n=10, mode='grid', variables=['x'])
|
||||
"""
|
||||
|
||||
if all(key in kwargs for key in ['n', 'mode']):
|
||||
argument = {}
|
||||
argument['n'] = kwargs['n']
|
||||
argument['mode'] = kwargs['mode']
|
||||
argument['variables'] = self.problem.input_variables
|
||||
arguments = [argument]
|
||||
elif any(key in kwargs for key in ['n', 'mode']) and args:
|
||||
raise ValueError("Don't mix args and kwargs")
|
||||
elif isinstance(args[0], int) and isinstance(args[1], str):
|
||||
argument = {}
|
||||
argument['n'] = int(args[0])
|
||||
argument['mode'] = args[1]
|
||||
argument['variables'] = self.problem.input_variables
|
||||
arguments = [argument]
|
||||
elif all(isinstance(arg, dict) for arg in args):
|
||||
arguments = args
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
locations = kwargs.get('locations', 'all')
|
||||
|
||||
if locations == 'all':
|
||||
locations = [condition for condition in self.problem.conditions]
|
||||
for location in locations:
|
||||
condition = self.problem.conditions[location]
|
||||
|
||||
samples = tuple(condition.location.sample(
|
||||
argument['n'],
|
||||
argument['mode'],
|
||||
variables=argument['variables'])
|
||||
for argument in arguments)
|
||||
pts = merge_tensors(samples)
|
||||
|
||||
# TODO
|
||||
# pts = pts.double()
|
||||
self.input_pts[location] = pts
|
||||
|
||||
def _residual_loss(self, input_pts, equation):
|
||||
"""
|
||||
Compute the residual loss for a given condition.
|
||||
|
||||
:param torch.Tensor pts: the points to evaluate the residual at.
|
||||
:param Equation equation: the equation to evaluate the residual with.
|
||||
"""
|
||||
|
||||
input_pts = input_pts.to(dtype=self.dtype, device=self.device)
|
||||
input_pts.requires_grad_(True)
|
||||
input_pts.retain_grad()
|
||||
|
||||
predicted = self.model(input_pts)
|
||||
residuals = equation.residual(input_pts, predicted)
|
||||
return self._compute_norm(residuals)
|
||||
|
||||
def _data_loss(self, input_pts, output_pts):
|
||||
"""
|
||||
Compute the residual loss for a given condition.
|
||||
|
||||
:param torch.Tensor pts: the points to evaluate the residual at.
|
||||
:param Equation equation: the equation to evaluate the residual with.
|
||||
"""
|
||||
input_pts = input_pts.to(dtype=self.dtype, device=self.device)
|
||||
output_pts = output_pts.to(dtype=self.dtype, device=self.device)
|
||||
predicted = self.model(input_pts)
|
||||
residuals = predicted - output_pts
|
||||
return self._compute_norm(residuals)
|
||||
|
||||
|
||||
# def closure(self):
|
||||
# """
|
||||
# """
|
||||
# self.optimizer.zero_grad()
|
||||
|
||||
# condition_losses = []
|
||||
# from torch.utils.data import DataLoader
|
||||
# from .utils import MyDataset
|
||||
# loader = DataLoader(
|
||||
# MyDataset(self.input_pts),
|
||||
# batch_size=self.batch_size,
|
||||
# num_workers=1
|
||||
# )
|
||||
# for condition_name in self.problem.conditions:
|
||||
# condition = self.problem.conditions[condition_name]
|
||||
|
||||
# batch_losses = []
|
||||
# for batch in data_loader[condition_name]:
|
||||
|
||||
# if hasattr(condition, 'equation'):
|
||||
# loss = self._residual_loss(
|
||||
# batch[condition_name], condition.equation)
|
||||
# elif hasattr(condition, 'output_points'):
|
||||
# loss = self._data_loss(
|
||||
# batch[condition_name], condition.output_points)
|
||||
|
||||
# batch_losses.append(loss * condition.data_weight)
|
||||
|
||||
# condition_losses.append(sum(batch_losses))
|
||||
|
||||
# loss = sum(condition_losses)
|
||||
# loss.backward()
|
||||
# return loss
|
||||
|
||||
def closure(self):
|
||||
"""
|
||||
"""
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
losses = []
|
||||
for i, batch in enumerate(self.loader):
|
||||
|
||||
condition_losses = []
|
||||
|
||||
@@ -302,52 +100,20 @@ class PINN(object):
|
||||
if condition_name not in self.problem.conditions:
|
||||
raise RuntimeError('Something wrong happened.')
|
||||
|
||||
if samples is None or samples.nelement() == 0:
|
||||
continue
|
||||
|
||||
condition = self.problem.conditions[condition_name]
|
||||
|
||||
# PINN loss: equation evaluated on location or input_points
|
||||
if hasattr(condition, 'equation'):
|
||||
loss = self._residual_loss(samples, condition.equation)
|
||||
target = condition.equation.residual(samples, self.forward(samples))
|
||||
loss = self._loss(torch.zeros_like(target), target)
|
||||
# PINN loss: evaluate model(input_points) vs output_points
|
||||
elif hasattr(condition, 'output_points'):
|
||||
loss = self._data_loss(samples, condition.output_points)
|
||||
input_pts, output_pts = samples
|
||||
loss = self._loss(self.forward(input_pts), output_pts)
|
||||
|
||||
condition_losses.append(loss * condition.data_weight)
|
||||
|
||||
losses.append(sum(condition_losses))
|
||||
|
||||
loss = sum(losses)
|
||||
loss.backward()
|
||||
return losses[0]
|
||||
|
||||
def train(self, stop=100):
|
||||
|
||||
self.model.train()
|
||||
|
||||
############################################################
|
||||
## TODO: move to problem class
|
||||
for condition in list(set(self.problem.conditions.keys()) - set(self.input_pts.keys())):
|
||||
self.input_pts[condition] = self.problem.conditions[condition].input_points
|
||||
|
||||
mydata = self.input_pts
|
||||
|
||||
self.loader = DummyLoader(mydata)
|
||||
|
||||
while True:
|
||||
|
||||
loss = self.optimizer.step(closure=self.closure)
|
||||
|
||||
self.writer.write_loss_in_loop(self, loss)
|
||||
|
||||
#self._lr_scheduler.step()
|
||||
|
||||
if isinstance(stop, int):
|
||||
if self.trained_epoch == stop:
|
||||
break
|
||||
elif isinstance(stop, float):
|
||||
if loss.item() < stop:
|
||||
break
|
||||
|
||||
self.trained_epoch += 1
|
||||
|
||||
self.model.eval()
|
||||
# TODO Fix the bug, tot_loss is a label tensor without labels
|
||||
# we need to pass it as a torch tensor to make everything work
|
||||
total_loss = sum(condition_losses)
|
||||
return total_loss
|
||||
@@ -11,11 +11,11 @@ class Plotter:
|
||||
Implementation of a plotter class, for easy visualizations.
|
||||
"""
|
||||
|
||||
def plot_samples(self, pinn, variables=None):
|
||||
def plot_samples(self, solver, variables=None):
|
||||
"""
|
||||
Plot a sample of solution.
|
||||
Plot the training grid samples.
|
||||
|
||||
:param PINN pinn: the PINN object.
|
||||
:param SolverInterface solver: the SolverInterface object.
|
||||
:param list(str) variables: variables to plot. If None, all variables
|
||||
are plotted. If 'spatial', only spatial variables are plotted. If
|
||||
'temporal', only temporal variables are plotted. Defaults to None.
|
||||
@@ -26,15 +26,15 @@ class Plotter:
|
||||
|
||||
:Example:
|
||||
>>> plotter = Plotter()
|
||||
>>> plotter.plot_samples(pinn=pinn, variables='spatial')
|
||||
>>> plotter.plot_samples(solver=solver, variables='spatial')
|
||||
"""
|
||||
|
||||
if variables is None:
|
||||
variables = pinn.problem.domain.variables
|
||||
variables = solver.problem.domain.variables
|
||||
elif variables == 'spatial':
|
||||
variables = pinn.problem.spatial_domain.variables
|
||||
variables = solver.problem.spatial_domain.variables
|
||||
elif variables == 'temporal':
|
||||
variables = pinn.problem.temporal_domain.variables
|
||||
variables = solver.problem.temporal_domain.variables
|
||||
|
||||
if len(variables) not in [1, 2, 3]:
|
||||
raise ValueError
|
||||
@@ -42,8 +42,8 @@ class Plotter:
|
||||
fig = plt.figure()
|
||||
proj = '3d' if len(variables) == 3 else None
|
||||
ax = fig.add_subplot(projection=proj)
|
||||
for location in pinn.input_pts:
|
||||
coords = pinn.input_pts[location].extract(variables).T.detach()
|
||||
for location in solver.problem.input_pts:
|
||||
coords = solver.problem.input_pts[location].extract(variables).T.detach()
|
||||
if coords.shape[0] == 1: # 1D samples
|
||||
ax.plot(coords[0], torch.zeros(coords[0].shape), '.',
|
||||
label=location)
|
||||
@@ -69,7 +69,7 @@ class Plotter:
|
||||
|
||||
:param pts: Points to plot the solution.
|
||||
:type pts: torch.Tensor
|
||||
:param pred: PINN solution evaluated at 'pts'.
|
||||
:param pred: SolverInterface solution evaluated at 'pts'.
|
||||
:type pred: torch.Tensor
|
||||
:param method: not used, kept for code compatibility
|
||||
:type method: None
|
||||
@@ -95,7 +95,7 @@ class Plotter:
|
||||
|
||||
:param pts: Points to plot the solution.
|
||||
:type pts: torch.Tensor
|
||||
:param pred: PINN solution evaluated at 'pts'.
|
||||
:param pred: SolverInterface solution evaluated at 'pts'.
|
||||
:type pred: torch.Tensor
|
||||
:param method: matplotlib method to plot 2-dimensional data,
|
||||
see https://matplotlib.org/stable/api/axes_api.html for
|
||||
@@ -129,12 +129,12 @@ class Plotter:
|
||||
*grids, pred_output.cpu().detach(), **kwargs)
|
||||
fig.colorbar(cb, ax=ax)
|
||||
|
||||
def plot(self, pinn, components=None, fixed_variables={}, method='contourf',
|
||||
def plot(self, solver, components=None, fixed_variables={}, method='contourf',
|
||||
res=256, filename=None, **kwargs):
|
||||
"""
|
||||
Plot sample of PINN output.
|
||||
Plot sample of SolverInterface output.
|
||||
|
||||
:param PINN pinn: the PINN object.
|
||||
:param SolverInterface solver: the SolverInterface object.
|
||||
:param list(str) components: the output variable to plot. If None, all
|
||||
the output variables of the problem are selected. Default value is
|
||||
None.
|
||||
@@ -150,12 +150,12 @@ class Plotter:
|
||||
is shown using the setted matplotlib frontend. Default is None.
|
||||
"""
|
||||
if components is None:
|
||||
components = [pinn.problem.output_variables]
|
||||
components = [solver.problem.output_variables]
|
||||
v = [
|
||||
var for var in pinn.problem.input_variables
|
||||
var for var in solver.problem.input_variables
|
||||
if var not in fixed_variables.keys()
|
||||
]
|
||||
pts = pinn.problem.domain.sample(res, 'grid', variables=v)
|
||||
pts = solver.problem.domain.sample(res, 'grid', variables=v)
|
||||
|
||||
fixed_pts = torch.ones(pts.shape[0], len(fixed_variables))
|
||||
fixed_pts *= torch.tensor(list(fixed_variables.values()))
|
||||
@@ -163,15 +163,15 @@ class Plotter:
|
||||
fixed_pts.labels = list(fixed_variables.keys())
|
||||
|
||||
pts = pts.append(fixed_pts)
|
||||
pts = pts.to(device=pinn.device)
|
||||
pts = pts.to(device=solver.device)
|
||||
|
||||
predicted_output = pinn.model(pts)
|
||||
predicted_output = solver.forward(pts)
|
||||
if isinstance(components, str):
|
||||
predicted_output = predicted_output.extract(components)
|
||||
elif callable(components):
|
||||
predicted_output = components(predicted_output)
|
||||
|
||||
truth_solution = getattr(pinn.problem, 'truth_solution', None)
|
||||
truth_solution = getattr(solver.problem, 'truth_solution', None)
|
||||
if len(v) == 1:
|
||||
self._1d_plot(pts, predicted_output, method, truth_solution,
|
||||
**kwargs)
|
||||
@@ -186,37 +186,25 @@ class Plotter:
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_loss(self, pinn, label=None, log_scale=True, filename=None):
|
||||
"""
|
||||
Plot the loss function values during traininig.
|
||||
# TODO loss
|
||||
# def plot_loss(self, solver, label=None, log_scale=True):
|
||||
# """
|
||||
# Plot the loss function values during traininig.
|
||||
|
||||
:param PINN pinn: the PINN object.
|
||||
:param str label: the label to use in the legend, defaults to None.
|
||||
:param bool log_scale: If True, the y axis is in log scale. Default is
|
||||
True.
|
||||
:param str filename: the file name to save the plot. If None, the plot
|
||||
is not saved. Default is None.
|
||||
"""
|
||||
# :param SolverInterface solver: the SolverInterface object.
|
||||
# :param str label: the label to use in the legend, defaults to None.
|
||||
# :param bool log_scale: If True, the y axis is in log scale. Default is
|
||||
# True.
|
||||
# """
|
||||
|
||||
if not label:
|
||||
label = str(pinn)
|
||||
# if not label:
|
||||
# label = str(solver)
|
||||
|
||||
epochs = list(pinn.history_loss.keys())
|
||||
loss = np.array(list(pinn.history_loss.values()))
|
||||
# epochs = list(solver.history_loss.keys())
|
||||
# loss = np.array(list(solver.history_loss.values()))
|
||||
# if loss.ndim != 1:
|
||||
# loss = loss[:, 0]
|
||||
|
||||
# if multiple outputs, sum the loss
|
||||
if loss.ndim != 1:
|
||||
loss = np.sum(loss, axis=1)
|
||||
|
||||
# plot loss
|
||||
plt.plot(epochs, loss, label=label)
|
||||
plt.legend()
|
||||
if log_scale:
|
||||
plt.yscale('log')
|
||||
plt.title('Loss function')
|
||||
plt.xlabel('Epochs')
|
||||
plt.ylabel('Loss')
|
||||
|
||||
# save plot
|
||||
if filename:
|
||||
plt.savefig(filename)
|
||||
# plt.plot(epochs, loss, label=label)
|
||||
# if log_scale:
|
||||
# plt.yscale('log')
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
""" Module for AbstractProblem class """
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from ..utils import merge_tensors
|
||||
|
||||
|
||||
class AbstractProblem(metaclass=ABCMeta):
|
||||
@@ -11,6 +12,19 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
the output variables, the condition(s), and the domain(s) where the
|
||||
conditions are applied.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
# variable storing all points
|
||||
self.input_pts = {}
|
||||
|
||||
# varible to check if sampling is done. If no location
|
||||
# element is presented in Condition this variable is set to true
|
||||
self._have_sampled_points = {}
|
||||
|
||||
# put in self.input_pts all the points that we don't need to sample
|
||||
self._span_condition_points()
|
||||
|
||||
@property
|
||||
def input_variables(self):
|
||||
"""
|
||||
@@ -80,3 +94,91 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
The conditions of the problem.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _span_condition_points(self):
|
||||
"""
|
||||
Simple function to get the condition points
|
||||
"""
|
||||
for condition_name in self.conditions:
|
||||
condition = self.conditions[condition_name]
|
||||
if hasattr(condition, 'equation') and hasattr(condition, 'input_points'):
|
||||
samples = condition.input_points
|
||||
elif hasattr(condition, 'output_points') and hasattr(condition, 'input_points'):
|
||||
samples = (condition.input_points, condition.output_points)
|
||||
# skip if we need to sample
|
||||
elif hasattr(condition, 'location'):
|
||||
self._have_sampled_points[condition_name] = False
|
||||
continue
|
||||
self.input_pts[condition_name] = samples
|
||||
|
||||
def discretise_domain(self, *args, **kwargs):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
the problem.
|
||||
|
||||
>>> pinn.span_pts(n=10, mode='grid')
|
||||
>>> pinn.span_pts(n=10, mode='grid', location=['bound1'])
|
||||
>>> pinn.span_pts(n=10, mode='grid', variables=['x'])
|
||||
"""
|
||||
if all(key in kwargs for key in ['n', 'mode']):
|
||||
argument = {}
|
||||
argument['n'] = kwargs['n']
|
||||
argument['mode'] = kwargs['mode']
|
||||
argument['variables'] = self.input_variables
|
||||
arguments = [argument]
|
||||
elif any(key in kwargs for key in ['n', 'mode']) and args:
|
||||
raise ValueError("Don't mix args and kwargs")
|
||||
elif isinstance(args[0], int) and isinstance(args[1], str):
|
||||
argument = {}
|
||||
argument['n'] = int(args[0])
|
||||
argument['mode'] = args[1]
|
||||
argument['variables'] = self.input_variables
|
||||
arguments = [argument]
|
||||
elif all(isinstance(arg, dict) for arg in args):
|
||||
arguments = args
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
locations = kwargs.get('locations', 'all')
|
||||
|
||||
if locations == 'all':
|
||||
locations = [condition for condition in self.conditions]
|
||||
for location in locations:
|
||||
condition = self.conditions[location]
|
||||
|
||||
samples = tuple(condition.location.sample(
|
||||
argument['n'],
|
||||
argument['mode'],
|
||||
variables=argument['variables'])
|
||||
for argument in arguments)
|
||||
pts = merge_tensors(samples)
|
||||
self.input_pts[location] = pts
|
||||
# setting the grad
|
||||
self.input_pts[location].requires_grad_(True)
|
||||
self.input_pts[location].retain_grad()
|
||||
# the condition is sampled
|
||||
self._have_sampled_points[location] = True
|
||||
|
||||
@property
|
||||
def have_sampled_points(self):
|
||||
"""
|
||||
Check if all points for
|
||||
``'Location'`` are sampled.
|
||||
"""
|
||||
return all(self._have_sampled_points.values())
|
||||
|
||||
@property
|
||||
def not_sampled_points(self):
|
||||
"""Check which points are
|
||||
not sampled.
|
||||
"""
|
||||
# variables which are not sampled
|
||||
not_sampled = None
|
||||
if self.have_sampled_points is False:
|
||||
# check which one are not sampled:
|
||||
not_sampled = []
|
||||
for condition_name, is_sample in self._have_sampled_points.items():
|
||||
if not is_sample:
|
||||
not_sampled.append(condition_name)
|
||||
return not_sampled
|
||||
|
||||
|
||||
65
pina/solver.py
Normal file
65
pina/solver.py
Normal file
@@ -0,0 +1,65 @@
|
||||
""" Solver module. """
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from .model.network import Network
|
||||
import lightning.pytorch as pl
|
||||
from .utils import check_consistency
|
||||
from .problem import AbstractProblem
|
||||
|
||||
class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
|
||||
""" Solver base class. """
|
||||
def __init__(self, model, problem, extra_features=None):
|
||||
"""
|
||||
:param model: A torch neural network model instance.
|
||||
:type model: torch.nn.Module
|
||||
:param problem: A problem definition instance.
|
||||
:type problem: AbstractProblem
|
||||
:param list(torch.nn.Module) extra_features: the additional input
|
||||
features to use as augmented input.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# check inheritance for pina problem
|
||||
check_consistency(problem, AbstractProblem, 'pina problem')
|
||||
|
||||
# assigning class variables (check consistency inside Network class)
|
||||
self._pina_model = Network(model=model, extra_features=extra_features)
|
||||
self._problem = problem
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def configure_optimizers(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""
|
||||
The torch model."""
|
||||
return self._pina_model
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
"""
|
||||
The problem formulation."""
|
||||
return self._problem
|
||||
|
||||
# @model.setter
|
||||
# def model(self, new_model):
|
||||
# """
|
||||
# Set the torch."""
|
||||
# check_consistency(new_model, nn.Module, 'torch model')
|
||||
# self._model= new_model
|
||||
|
||||
# @problem.setter
|
||||
# def problem(self, problem):
|
||||
# """
|
||||
# Set the problem formulation."""
|
||||
# check_consistency(problem, AbstractProblem, 'pina problem')
|
||||
# self._problem = problem
|
||||
31
pina/trainer.py
Normal file
31
pina/trainer.py
Normal file
@@ -0,0 +1,31 @@
|
||||
""" Solver module. """
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from .utils import check_consistency
|
||||
from .dataset import DummyLoader
|
||||
from .solver import SolverInterface
|
||||
|
||||
class Trainer(pl.Trainer):
|
||||
|
||||
def __init__(self, solver, kwargs={}):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# check inheritance consistency for solver
|
||||
check_consistency(solver, SolverInterface, 'Solver model')
|
||||
self._model = solver
|
||||
|
||||
# create dataloader
|
||||
if solver.problem.have_sampled_points is False:
|
||||
raise RuntimeError(f'Input points in {solver.problem.not_sampled_points} '
|
||||
'training are None. Please '
|
||||
'sample points in your problem by calling '
|
||||
'discretise_domain function before train '
|
||||
'in the provided locations.')
|
||||
|
||||
# TODO: make a better dataloader for train
|
||||
self._loader = DummyLoader(solver.problem.input_pts)
|
||||
|
||||
|
||||
def train(self): # TODO add kwargs and lightining capabilities
|
||||
return super().fit(self._model, self._loader)
|
||||
|
||||
@@ -10,6 +10,29 @@ from .label_tensor import LabelTensor
|
||||
import torch
|
||||
|
||||
|
||||
def check_consistency(object, object_instance, object_name, subclass=False):
|
||||
"""Helper function to check object inheritance consistency.
|
||||
Given a specific ``'object'`` we check if the object is
|
||||
instance of a specific ``'object_instance'``, or in case
|
||||
``'subclass=True'`` we check if the object is subclass
|
||||
if the ``'object_instance'``.
|
||||
|
||||
:param Object object: The object to check the inheritance
|
||||
:param Object object_instance: The parent class from where the object
|
||||
is expected to inherit
|
||||
:param str object_name: The name of the object
|
||||
:param bool subclass: Check if is a subclass and not instance
|
||||
:raises ValueError: If the object does not inherit from the
|
||||
specified class
|
||||
"""
|
||||
if not subclass:
|
||||
if not isinstance(object, object_instance):
|
||||
raise ValueError(f"{object_name} must be {object_instance}")
|
||||
else:
|
||||
if not issubclass(object, object_instance):
|
||||
raise ValueError(f"{object_name} must be {object_instance}")
|
||||
|
||||
|
||||
def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check
|
||||
"""
|
||||
Return the number of parameters of a given `model`.
|
||||
@@ -189,8 +212,7 @@ class LabelTensorDataset(Dataset):
|
||||
class LabelTensorDataLoader(DataLoader):
|
||||
|
||||
def collate_fn(self, data):
|
||||
print(data)
|
||||
gggggggggg
|
||||
pass
|
||||
# return dict(zip(self.pinn.input_pts.keys(), dataloaders))
|
||||
|
||||
# class SampleDataset(torch.utils.data.Dataset):
|
||||
@@ -239,5 +261,4 @@ class LabelTensorDataset(Dataset):
|
||||
class LabelTensorDataLoader(DataLoader):
|
||||
|
||||
def collate_fn(self, data):
|
||||
print(data)
|
||||
gggggggggg
|
||||
pass
|
||||
49
tests/test_loss.py
Normal file
49
tests/test_loss.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from pina.loss import *
|
||||
|
||||
input = torch.tensor([[3.], [1.], [-8.]])
|
||||
target = torch.tensor([[6.], [4.], [2.]])
|
||||
available_reductions = ['str', 'mean', 'none']
|
||||
|
||||
|
||||
def test_LpLoss_constructor():
|
||||
# test reduction
|
||||
for reduction in available_reductions:
|
||||
LpLoss(reduction=reduction)
|
||||
# test p
|
||||
for p in [float('inf'), -float('inf'), 1, 10, -8]:
|
||||
LpLoss(p=p)
|
||||
|
||||
def test_LpLoss_forward():
|
||||
# l2 loss
|
||||
loss = LpLoss(p=2, reduction='mean')
|
||||
l2_loss = torch.mean(torch.sqrt((input-target).pow(2)))
|
||||
assert loss(input, target) == l2_loss
|
||||
# l1 loss
|
||||
loss = LpLoss(p=1, reduction='sum')
|
||||
l1_loss = torch.sum(torch.abs(input-target))
|
||||
assert loss(input, target) == l1_loss
|
||||
|
||||
def test_LpRelativeLoss_constructor():
|
||||
# test reduction
|
||||
for reduction in available_reductions:
|
||||
LpLoss(reduction=reduction, relative=True)
|
||||
# test p
|
||||
for p in [float('inf'), -float('inf'), 1, 10, -8]:
|
||||
LpLoss(p=p,relative=True)
|
||||
|
||||
def test_LpRelativeLoss_forward():
|
||||
# l2 relative loss
|
||||
loss = LpLoss(p=2, reduction='mean',relative=True)
|
||||
l2_loss = torch.sqrt((input-target).pow(2))/torch.sqrt(input.pow(2))
|
||||
assert loss(input, target) == torch.mean(l2_loss)
|
||||
# l1 relative loss
|
||||
loss = LpLoss(p=1, reduction='sum',relative=True)
|
||||
l1_loss = torch.abs(input-target)/torch.abs(input)
|
||||
assert loss(input, target) == torch.sum(l1_loss)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
from pina.model import Network, FeedForward
|
||||
from pina import LabelTensor
|
||||
|
||||
|
||||
class myFeature(torch.nn.Module):
|
||||
"""
|
||||
Feature: sin(x)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(myFeature, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
t = (torch.sin(x.extract(['x'])*torch.pi) *
|
||||
torch.sin(x.extract(['y'])*torch.pi))
|
||||
return LabelTensor(t, ['sin(x)sin(y)'])
|
||||
|
||||
|
||||
input_variables = ['x', 'y']
|
||||
output_variables = ['u']
|
||||
data = torch.rand((20, 2))
|
||||
input_ = LabelTensor(data, input_variables)
|
||||
|
||||
|
||||
def test_constructor():
|
||||
net = FeedForward(2, 1)
|
||||
pina_net = Network(model=net, input_variables=input_variables,
|
||||
output_variables=output_variables)
|
||||
|
||||
|
||||
def test_forward():
|
||||
net = FeedForward(2, 1)
|
||||
pina_net = Network(model=net, input_variables=input_variables,
|
||||
output_variables=output_variables)
|
||||
output_ = pina_net(input_)
|
||||
assert output_.labels == output_variables
|
||||
|
||||
|
||||
def test_constructor_extrafeat():
|
||||
net = FeedForward(3, 1)
|
||||
feat = [myFeature()]
|
||||
pina_net = Network(model=net, input_variables=input_variables,
|
||||
output_variables=output_variables, extra_features=feat)
|
||||
|
||||
|
||||
def test_forward_extrafeat():
|
||||
net = FeedForward(3, 1)
|
||||
feat = [myFeature()]
|
||||
pina_net = Network(model=net, input_variables=input_variables,
|
||||
output_variables=output_variables, extra_features=feat)
|
||||
output_ = pina_net(input_)
|
||||
assert output_.labels == output_variables
|
||||
@@ -1,17 +1,18 @@
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from pina import LabelTensor, Condition, CartesianDomain, PINN
|
||||
from pina.problem import SpatialProblem
|
||||
from pina.model import FeedForward
|
||||
from pina.operators import nabla
|
||||
from pina.geometry import CartesianDomain
|
||||
from pina import Condition, LabelTensor, PINN
|
||||
from pina.trainer import Trainer
|
||||
from pina.model import FeedForward
|
||||
from pina.equation.equation import Equation
|
||||
from pina.equation.equation_factory import FixedValue
|
||||
from pina.plotter import Plotter
|
||||
from pina.loss import LpLoss
|
||||
|
||||
|
||||
in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y'])
|
||||
out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
|
||||
|
||||
def laplace_equation(input_, output_):
|
||||
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
|
||||
torch.sin(input_.extract(['y'])*torch.pi))
|
||||
@@ -19,6 +20,8 @@ def laplace_equation(input_, output_):
|
||||
return nabla_u - force_term
|
||||
|
||||
my_laplace = Equation(laplace_equation)
|
||||
in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y'])
|
||||
out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u'])
|
||||
|
||||
class Poisson(SpatialProblem):
|
||||
output_variables = ['u']
|
||||
@@ -68,75 +71,40 @@ class myFeature(torch.nn.Module):
|
||||
return LabelTensor(t, ['sin(x)sin(y)'])
|
||||
|
||||
|
||||
problem = Poisson()
|
||||
model = FeedForward(len(problem.input_variables),len(problem.output_variables))
|
||||
model_extra_feat = FeedForward(len(problem.input_variables) + 1,len(problem.output_variables))
|
||||
# make the problem
|
||||
poisson_problem = Poisson()
|
||||
model = FeedForward(len(poisson_problem.input_variables),len(poisson_problem.output_variables))
|
||||
model_extra_feats = FeedForward(len(poisson_problem.input_variables)+1,len(poisson_problem.output_variables))
|
||||
extra_feats = [myFeature()]
|
||||
|
||||
|
||||
def test_constructor():
|
||||
PINN(problem, model)
|
||||
PINN(problem = poisson_problem, model=model, extra_features=None)
|
||||
|
||||
|
||||
def test_constructor_extra_feats():
|
||||
PINN(problem, model_extra_feat, [myFeature()])
|
||||
|
||||
|
||||
def test_span_pts():
|
||||
pinn = PINN(problem, model)
|
||||
n = 10
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
for b in boundaries:
|
||||
assert pinn.input_pts[b].shape[0] == n
|
||||
pinn.span_pts(n, 'random', locations=boundaries)
|
||||
for b in boundaries:
|
||||
assert pinn.input_pts[b].shape[0] == n
|
||||
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
assert pinn.input_pts['D'].shape[0] == n**2
|
||||
pinn.span_pts(n, 'random', locations=['D'])
|
||||
assert pinn.input_pts['D'].shape[0] == n
|
||||
|
||||
pinn.span_pts(n, 'latin', locations=['D'])
|
||||
assert pinn.input_pts['D'].shape[0] == n
|
||||
|
||||
pinn.span_pts(n, 'lh', locations=['D'])
|
||||
assert pinn.input_pts['D'].shape[0] == n
|
||||
|
||||
|
||||
def test_sampling_all_args():
|
||||
pinn = PINN(problem, model)
|
||||
n = 10
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
|
||||
|
||||
def test_sampling_all_kwargs():
|
||||
pinn = PINN(problem, model)
|
||||
n = 10
|
||||
pinn.span_pts(n=n, mode='latin', locations=['D'])
|
||||
|
||||
|
||||
def test_sampling_dict():
|
||||
pinn = PINN(problem, model)
|
||||
n = 10
|
||||
pinn.span_pts(
|
||||
{'variables': ['x', 'y'], 'mode': 'grid', 'n': n}, locations=['D'])
|
||||
|
||||
|
||||
def test_sampling_mixed_args_kwargs():
|
||||
pinn = PINN(problem, model)
|
||||
n = 10
|
||||
with pytest.raises(ValueError):
|
||||
pinn.span_pts(n, mode='latin', locations=['D'])
|
||||
|
||||
model_extra_feats = FeedForward(len(poisson_problem.input_variables)+1,len(poisson_problem.output_variables))
|
||||
PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
|
||||
|
||||
def test_train():
|
||||
pinn = PINN(problem, model)
|
||||
poisson_problem = Poisson()
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
n = 10
|
||||
pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.train(5)
|
||||
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
|
||||
pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss())
|
||||
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5})
|
||||
trainer.train()
|
||||
|
||||
def test_train_extra_feats():
|
||||
poisson_problem = Poisson()
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
n = 10
|
||||
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
|
||||
pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats)
|
||||
trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5})
|
||||
trainer.train()
|
||||
|
||||
"""
|
||||
def test_train_2():
|
||||
@@ -146,8 +114,8 @@ def test_train_2():
|
||||
param = [0, 3]
|
||||
for i, truth_key in zip(param, expected_keys):
|
||||
pinn = PINN(problem, model)
|
||||
pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||
pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||
pinn.train(50, save_loss=i)
|
||||
assert list(pinn.history_loss.keys()) == truth_key
|
||||
|
||||
@@ -156,8 +124,8 @@ def test_train_extra_feats():
|
||||
pinn = PINN(problem, model_extra_feat, [myFeature()])
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
n = 10
|
||||
pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||
pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||
pinn.train(5)
|
||||
|
||||
|
||||
@@ -168,8 +136,8 @@ def test_train_2_extra_feats():
|
||||
param = [0, 3]
|
||||
for i, truth_key in zip(param, expected_keys):
|
||||
pinn = PINN(problem, model_extra_feat, [myFeature()])
|
||||
pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||
pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||
pinn.train(50, save_loss=i)
|
||||
assert list(pinn.history_loss.keys()) == truth_key
|
||||
|
||||
@@ -181,8 +149,8 @@ def test_train_with_optimizer_kwargs():
|
||||
param = [0, 3]
|
||||
for i, truth_key in zip(param, expected_keys):
|
||||
pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3})
|
||||
pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||
pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||
pinn.train(50, save_loss=i)
|
||||
assert list(pinn.history_loss.keys()) == truth_key
|
||||
|
||||
@@ -199,8 +167,8 @@ def test_train_with_lr_scheduler():
|
||||
lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
|
||||
lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False}
|
||||
)
|
||||
pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||
pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||
pinn.train(50, save_loss=i)
|
||||
assert list(pinn.history_loss.keys()) == truth_key
|
||||
|
||||
@@ -209,8 +177,8 @@ def test_train_with_lr_scheduler():
|
||||
# pinn = PINN(problem, model, batch_size=6)
|
||||
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
# n = 10
|
||||
# pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
# pinn.span_pts(n, 'grid', locations=['D'])
|
||||
# pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||
# pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||
# pinn.train(5)
|
||||
|
||||
|
||||
@@ -221,8 +189,8 @@ def test_train_with_lr_scheduler():
|
||||
# param = [0, 3]
|
||||
# for i, truth_key in zip(param, expected_keys):
|
||||
# pinn = PINN(problem, model, batch_size=6)
|
||||
# pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
# pinn.span_pts(n, 'grid', locations=['D'])
|
||||
# pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||
# pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||
# pinn.train(50, save_loss=i)
|
||||
# assert list(pinn.history_loss.keys()) == truth_key
|
||||
|
||||
@@ -233,15 +201,15 @@ if torch.cuda.is_available():
|
||||
# pinn = PINN(problem, model, batch_size=20, device='cuda')
|
||||
# boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
# n = 100
|
||||
# pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
# pinn.span_pts(n, 'grid', locations=['D'])
|
||||
# pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||
# pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||
# pinn.train(5)
|
||||
|
||||
def test_gpu_train_nobatch():
|
||||
pinn = PINN(problem, model, batch_size=None, device='cuda')
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
n = 100
|
||||
pinn.span_pts(n, 'grid', locations=boundaries)
|
||||
pinn.span_pts(n, 'grid', locations=['D'])
|
||||
pinn.discretise_domain(n, 'grid', locations=boundaries)
|
||||
pinn.discretise_domain(n, 'grid', locations=['D'])
|
||||
pinn.train(5)
|
||||
"""
|
||||
97
tests/test_problem.py
Normal file
97
tests/test_problem.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from pina.problem import SpatialProblem
|
||||
from pina.operators import nabla
|
||||
from pina import LabelTensor, Condition
|
||||
from pina.geometry import CartesianDomain
|
||||
from pina.equation.equation import Equation
|
||||
from pina.equation.equation_factory import FixedValue
|
||||
|
||||
|
||||
def laplace_equation(input_, output_):
|
||||
force_term = (torch.sin(input_.extract(['x'])*torch.pi) *
|
||||
torch.sin(input_.extract(['y'])*torch.pi))
|
||||
nabla_u = nabla(output_.extract(['u']), input_)
|
||||
return nabla_u - force_term
|
||||
|
||||
my_laplace = Equation(laplace_equation)
|
||||
in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y'])
|
||||
out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u'])
|
||||
|
||||
class Poisson(SpatialProblem):
|
||||
output_variables = ['u']
|
||||
spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
|
||||
|
||||
conditions = {
|
||||
'gamma1': Condition(
|
||||
location=CartesianDomain({'x': [0, 1], 'y': 1}),
|
||||
equation=FixedValue(0.0)),
|
||||
'gamma2': Condition(
|
||||
location=CartesianDomain({'x': [0, 1], 'y': 0}),
|
||||
equation=FixedValue(0.0)),
|
||||
'gamma3': Condition(
|
||||
location=CartesianDomain({'x': 1, 'y': [0, 1]}),
|
||||
equation=FixedValue(0.0)),
|
||||
'gamma4': Condition(
|
||||
location=CartesianDomain({'x': 0, 'y': [0, 1]}),
|
||||
equation=FixedValue(0.0)),
|
||||
'D': Condition(
|
||||
location=CartesianDomain({'x': [0, 1], 'y': [0, 1]}),
|
||||
equation=my_laplace),
|
||||
'data': Condition(
|
||||
input_points=in_,
|
||||
output_points=out_)
|
||||
}
|
||||
|
||||
def poisson_sol(self, pts):
|
||||
return -(
|
||||
torch.sin(pts.extract(['x'])*torch.pi) *
|
||||
torch.sin(pts.extract(['y'])*torch.pi)
|
||||
)/(2*torch.pi**2)
|
||||
|
||||
truth_solution = poisson_sol
|
||||
|
||||
|
||||
# make the problem
|
||||
poisson_problem = Poisson()
|
||||
|
||||
|
||||
def test_discretise_domain():
|
||||
n = 10
|
||||
boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
|
||||
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
|
||||
for b in boundaries:
|
||||
assert poisson_problem.input_pts[b].shape[0] == n
|
||||
poisson_problem.discretise_domain(n, 'random', locations=boundaries)
|
||||
for b in boundaries:
|
||||
assert poisson_problem.input_pts[b].shape[0] == n
|
||||
|
||||
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
|
||||
assert poisson_problem.input_pts['D'].shape[0] == n**2
|
||||
poisson_problem.discretise_domain(n, 'random', locations=['D'])
|
||||
assert poisson_problem.input_pts['D'].shape[0] == n
|
||||
|
||||
poisson_problem.discretise_domain(n, 'latin', locations=['D'])
|
||||
assert poisson_problem.input_pts['D'].shape[0] == n
|
||||
|
||||
poisson_problem.discretise_domain(n, 'lh', locations=['D'])
|
||||
assert poisson_problem.input_pts['D'].shape[0] == n
|
||||
|
||||
def test_sampling_all_args():
|
||||
n = 10
|
||||
poisson_problem.discretise_domain(n, 'grid', locations=['D'])
|
||||
|
||||
def test_sampling_all_kwargs():
|
||||
n = 10
|
||||
poisson_problem.discretise_domain(n=n, mode='latin', locations=['D'])
|
||||
|
||||
def test_sampling_dict():
|
||||
n = 10
|
||||
poisson_problem.discretise_domain(
|
||||
{'variables': ['x', 'y'], 'mode': 'grid', 'n': n}, locations=['D'])
|
||||
|
||||
def test_sampling_mixed_args_kwargs():
|
||||
n = 10
|
||||
with pytest.raises(ValueError):
|
||||
poisson_problem.discretise_domain(n, mode='latin', locations=['D'])
|
||||
Reference in New Issue
Block a user