From 63fd068988e5c59e48e67a655545336118763aa4 Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Wed, 7 Jun 2023 15:34:43 +0200 Subject: [PATCH] Lightining update (#104) * multiple functions for version 0.0 * lightining update * minor changes * data pinn loss added --------- Co-authored-by: Nicola Demo Co-authored-by: Dario Coscia Co-authored-by: Dario Coscia Co-authored-by: Dario Coscia Co-authored-by: Dario Coscia --- pina/__init__.py | 4 +- pina/dataset.py | 1 + pina/label_tensor.py | 9 +- pina/loss.py | 127 ++++++++++ pina/model/__init__.py | 2 - pina/model/network.py | 112 ++------- pina/pinn.py | 410 +++++++------------------------ pina/plotter.py | 88 +++---- pina/problem/abstract_problem.py | 102 ++++++++ pina/solver.py | 65 +++++ pina/trainer.py | 31 +++ pina/utils.py | 29 ++- tests/test_loss.py | 49 ++++ tests/test_model/test_network.py | 55 ----- tests/test_pinn.py | 132 ++++------ tests/test_problem.py | 97 ++++++++ 16 files changed, 710 insertions(+), 603 deletions(-) create mode 100644 pina/loss.py create mode 100644 pina/solver.py create mode 100644 pina/trainer.py create mode 100644 tests/test_loss.py delete mode 100644 tests/test_model/test_network.py create mode 100644 tests/test_problem.py diff --git a/pina/__init__.py b/pina/__init__.py index eda80de..b8d4063 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -1,5 +1,6 @@ __all__ = [ 'PINN', + 'Trainer', 'LabelTensor', 'Plotter', 'Condition', @@ -10,7 +11,8 @@ __all__ = [ from .meta import * from .label_tensor import LabelTensor from .pinn import PINN +from .trainer import Trainer from .plotter import Plotter from .condition import Condition from .geometry import Location -from .geometry import CartesianDomain +from .geometry import CartesianDomain \ No newline at end of file diff --git a/pina/dataset.py b/pina/dataset.py index fe7828f..ac81f62 100644 --- a/pina/dataset.py +++ b/pina/dataset.py @@ -117,6 +117,7 @@ class LabelTensorDataset(Dataset): def __len__(self): return max([len(getattr(self, label)) for label in self.labels]) +# TODO: working also for datapoints class DummyLoader: def __init__(self, data) -> None: diff --git a/pina/label_tensor.py b/pina/label_tensor.py index bdf3ec6..f79420c 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -88,6 +88,8 @@ class LabelTensor(torch.Tensor): self._labels = labels # assign the label + # TODO remove try/ except thing IMPORTANT + # make the label None of default def clone(self, *args, **kwargs): """ Clone the LabelTensor. For more details, see @@ -96,7 +98,12 @@ class LabelTensor(torch.Tensor): :return: a copy of the tensor :rtype: LabelTensor """ - return LabelTensor(super().clone(*args, **kwargs), self.labels) + try: + out = LabelTensor(super().clone(*args, **kwargs), self.labels) + except: + out = super().clone(*args, **kwargs) + + return out def to(self, *args, **kwargs): """ diff --git a/pina/loss.py b/pina/loss.py new file mode 100644 index 0000000..0073fb2 --- /dev/null +++ b/pina/loss.py @@ -0,0 +1,127 @@ +""" Module for EquationInterface class """ +from abc import ABCMeta, abstractmethod +from torch.nn.modules.loss import _Loss +import torch +from .utils import check_consistency + +__all__ = ['LpLoss'] + +class LossInterface(_Loss, metaclass=ABCMeta): + """ + The abstract `LossInterface` class. All the class defining a PINA Loss + should be inheritied from this class. + """ + + def __init__(self, reduction = 'mean'): + """ + :param str reduction: Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction + will be applied, ``'mean'``: the sum of the output will be divided + by the number of elements in the output, ``'sum'``: the output will + be summed. Note: :attr:`size_average` and :attr:`reduce` are in the + process of being deprecated, and in the meantime, specifying either of + those two args will override :attr:`reduction`. Default: ``'mean'``. + """ + super().__init__(reduction=reduction, size_average=None, reduce=None) + + @abstractmethod + def forward(self): + pass + + def _reduction(self, loss): + """Simple helper function to check reduction + + :param reduction: Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction + will be applied, ``'mean'``: the sum of the output will be divided + by the number of elements in the output, ``'sum'``: the output will + be summed. Note: :attr:`size_average` and :attr:`reduce` are in the + process of being deprecated, and in the meantime, specifying either of + those two args will override :attr:`reduction`. Default: ``'mean'``. + :type reduction: str, optional + :param loss: Loss tensor for each element. + :type loss: torch.Tensor + :return: Reduced loss. + :rtype: torch.Tensor + """ + if self.reduction == "none": + ret = loss + elif self.reduction == "mean": + ret = torch.mean(loss, keepdim=True, dim=-1) + elif self.reduction == "sum": + ret = torch.sum(loss, keepdim=True, dim=-1) + else: + raise ValueError(self.reduction + " is not valid") + return ret + +class LpLoss(LossInterface): + """ + The Lp loss implementation class. Creates a criterion that measures + the Lp error between each element in the input :math:`x` and + target :math:`y`. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can + be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left| x_n - y_n \right|^p, + + If ``'relative'`` is set to true: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left[\frac{\left| x_n - y_n \right|^p}{\left|y_n \right|^p}\right]^{1/p}, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. + + The sum operation still operates over all the elements, and divides by :math:`n`. + + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + """ + + def __init__(self, p=2, reduction = 'mean', relative = False): + """ + :param int p: Degree of Lp norm. It specifies the type of norm to + be calculated. See :meth:`torch.linalg.norm` ```'ord'``` to + see the possible degrees. Default 2 (euclidean norm). + :param str reduction: Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction + will be applied, ``'mean'``: the sum of the output will be divided + by the number of elements in the output, ``'sum'``: the output will + be summed. Note: :attr:`size_average` and :attr:`reduce` are in the + process of being deprecated, and in the meantime, specifying either of + those two args will override :attr:`reduction`. Default: ``'mean'``. + :param bool relative: Specifies if relative error should be computed. + """ + super().__init__(reduction=reduction) + + # check consistency + check_consistency(p, (str,int,float), 'degree p') + self.p = p + check_consistency(relative, bool, 'relative') + self.relative = relative + + def forward(self, input, target): + """Forward method for loss function. + + :param torch.Tensor input: Input tensor from real data. + :param torch.Tensor target: Model tensor output. + :return: Loss evaluation. + :rtype: torch.Tensor + """ + loss = torch.linalg.norm((input-target), ord=self.p, dim=-1) + if self.relative: + loss = loss / torch.linalg.norm(input, ord=self.p, dim=-1) + return self._reduction(loss) diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 5f55123..81fbc09 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -2,10 +2,8 @@ __all__ = [ 'FeedForward', 'MultiFeedForward', 'DeepONet', - 'Network', ] from .feed_forward import FeedForward from .multi_feed_forward import MultiFeedForward from .deeponet import DeepONet -from .network import Network diff --git a/pina/model/network.py b/pina/model/network.py index e4f50e7..752d1df 100644 --- a/pina/model/network.py +++ b/pina/model/network.py @@ -1,107 +1,47 @@ import torch -from pina.label_tensor import LabelTensor +import torch.nn as nn +from ..utils import check_consistency class Network(torch.nn.Module): - """The PINA implementation of any neural network. - - :param torch.nn.Module model: the torch model of the network. - :param list(str) input_variables: the list containing the labels - corresponding to the input components of the model. - :param list(str) output_variables: the list containing the labels - corresponding to the components of the output computed by the model. - :param torch.nn.Module extra_features: the additional input - features to use as augmented input. - - :Example: - >>> class SimpleNet(nn.Module): - >>> def __init__(self): - >>> super().__init__() - >>> self.layers = nn.Sequential( - >>> nn.Linear(2, 20), - >>> nn.Tanh(), - >>> nn.Linear(20, 1) - >>> ) - >>> def forward(self, x): - >>> return self.layers(x) - >>> net = SimpleNet() - >>> input_variables = ['x', 'y'] - >>> output_variables =['u'] - >>> model_feat = Network(net, input_variables, output_variables) - Network( - (extra_features): Sequential() - (model): Sequential( - (0): Linear(in_features=2, out_features=20, bias=True) - (1): Tanh() - (2): Linear(in_features=20, out_features=1, bias=True) - ) - ) - """ - - def __init__(self, model, input_variables, - output_variables, extra_features=None): + + def __init__(self, model, extra_features=None): super().__init__() - print('HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH') - if extra_features is None: - extra_features = [] - - self._extra_features = torch.nn.Sequential(*extra_features) + # check model consistency + check_consistency(model, nn.Module, 'torch model') self._model = model - self._input_variables = input_variables - self._output_variables = output_variables - print(output_variables) - # check model and input/output - self._check_consistency() + # check consistency and assign extra fatures + if extra_features is None: + self._extra_features = [] + else: + for feat in extra_features: + check_consistency(feat, nn.Module, 'extra features') + self._extra_features = nn.Sequential(*extra_features) - def _check_consistency(self): - """Checking the consistency of model with input and output variables - - :raises ValueError: Error in constructing the PINA network - """ - try: - pass - # tmp = torch.rand((10, len(self._input_variables))) - # tmp = LabelTensor(tmp, self._input_variables) - # tmp = self.forward(tmp) # trying a forward pass - # tmp = LabelTensor(tmp, self._output_variables) - except: - raise ValueError('Error in constructing the PINA network.' - ' Check compatibility of input/output' - ' variables shape with the torch model' - ' or check the correctness of the torch' - ' model itself.') + # check model works with inputs + # TODO def forward(self, x): - """Forward method for Network class + """ + Forward method for Network class. This class + implements the standard forward method, and + it adds the possibility to pass extra features. :param torch.tensor x: input of the network :return torch.tensor: output of the network """ - - x = x.extract(self._input_variables) - + # extract features and append for feature in self._extra_features: x = x.append(feature(x)) - - output = self._model(x).as_subclass(LabelTensor) - output.labels = self._output_variables - - return output - - @property - def input_variables(self): - return self._input_variables - - @property - def output_variables(self): - return self._output_variables - - @property - def extra_features(self): - return self._extra_features + # perform forward pass + return self._model(x) @property def model(self): return self._model + + @property + def extra_features(self): + return self._extra_features \ No newline at end of file diff --git a/pina/pinn.py b/pina/pinn.py index b7a22af..62c647b 100644 --- a/pina/pinn.py +++ b/pina/pinn.py @@ -2,352 +2,118 @@ import torch import torch.optim.lr_scheduler as lrs -from .problem import AbstractProblem -from .model import Network + +from .solver import SolverInterface from .label_tensor import LabelTensor -from .utils import merge_tensors -from .dataset import DummyLoader +from .utils import check_consistency +from .writer import Writer +from .loss import LossInterface +from torch.nn.modules.loss import _Loss torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 -class PINN(object): +class PINN(SolverInterface): def __init__(self, problem, model, extra_features=None, + loss = torch.nn.MSELoss(), optimizer=torch.optim.Adam, - optimizer_kwargs=None, - lr=0.001, - lr_scheduler_type=lrs.ConstantLR, - lr_scheduler_kwargs={"factor": 1, "total_iters": 0}, - regularizer=0.00001, - batch_size=None, - dtype=torch.float32, - device='cpu', - writer=None, - error_norm='mse'): + optimizer_kwargs={'lr' : 0.001}, + scheduler=lrs.ConstantLR, + scheduler_kwargs={"factor": 1, "total_iters": 0}, + ): ''' - :param AbstractProblem problem: the formualation of the problem. - :param torch.nn.Module model: the neural network model to use. - :param torch.nn.Module extra_features: the additional input + :param AbstractProblem problem: The formualation of the problem. + :param torch.nn.Module model: The neural network model to use. + :param torch.nn.Module loss: The loss function used as minimizer, + default torch.nn.MSELoss(). + :param torch.nn.Module extra_features: The additional input features to use as augmented input. - :param torch.optim.Optimizer optimizer: the neural network optimizer to + :param torch.optim.Optimizer optimizer: The neural network optimizer to use; default is `torch.optim.Adam`. :param dict optimizer_kwargs: Optimizer constructor keyword args. - :param float lr: the learning rate; default is 0.001. - :param torch.optim.LRScheduler lr_scheduler_type: Learning + :param float lr: The learning rate; default is 0.001. + :param torch.optim.LRScheduler scheduler: Learning rate scheduler. - :param dict lr_scheduler_kwargs: LR scheduler constructor keyword args. - :param float regularizer: the coefficient for L2 regularizer term. - :param type dtype: the data type to use for the model. Valid option are - `torch.float32` and `torch.float64` (`torch.float16` only on GPU); - default is `torch.float64`. - :param str device: the device used for training; default 'cpu' - option include 'cuda' if cuda is available. - :param (str, int) error_norm: the loss function used as minimizer, - default mean square error 'mse'. If string options include mean - error 'me' and mean square error 'mse'. If int, the p-norm is - calculated where p is specifined by the int input. - :param int batch_size: batch size for the dataloader; default 5. + :param dict scheduler_kwargs: LR scheduler constructor keyword args. ''' - - if dtype == torch.float64: - raise NotImplementedError('only float for now') - - self.problem = problem - - # self._architecture = architecture if architecture else dict() - # self._architecture['input_dimension'] = self.problem.domain_bound.shape[0] - # self._architecture['output_dimension'] = len(self.problem.variables) - # if hasattr(self.problem, 'params_domain'): - # self._architecture['input_dimension'] += self.problem.params_domain.shape[0] - - self.error_norm = error_norm - - if device == 'cuda' and not torch.cuda.is_available(): - raise RuntimeError - self.device = torch.device(device) - - self.dtype = dtype - self.history_loss = {} - - - self.model = Network(model=model, - input_variables=problem.input_variables, - output_variables=problem.output_variables, - extra_features=extra_features) - - self.model.to(dtype=self.dtype, device=self.device) - - self.truth_values = {} - self.input_pts = {} - - self.trained_epoch = 0 - - from .writer import Writer - if writer is None: - writer = Writer() - self.writer = writer - - if not optimizer_kwargs: - optimizer_kwargs = {} - optimizer_kwargs['lr'] = lr - self.optimizer = optimizer( - self.model.parameters())#, weight_decay=regularizer, **optimizer_kwargs) - #self._lr_scheduler = lr_scheduler_type( - # self.optimizer, **lr_scheduler_kwargs) - - self.batch_size = batch_size - # self.data_set = PinaDataset(self) - - @property - def problem(self): - """ The problem formulation.""" - return self._problem - - @problem.setter - def problem(self, problem): - """ - Set the problem formulation.""" - if not isinstance(problem, AbstractProblem): - raise TypeError - self._problem = problem - - def _compute_norm(self, vec): - """ - Compute the norm of the `vec` one-dimensional tensor based on the - `self.error_norm` attribute. - - .. todo: complete - - :param torch.Tensor vec: the tensor - """ - if isinstance(self.error_norm, int): - return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dytpe) - elif self.error_norm == 'mse': - return torch.mean(vec.pow(2)) - elif self.error_norm == 'me': - return torch.mean(torch.abs(vec)) - else: - raise RuntimeError - - def save_state(self, filename): - """ - Save the state of the model. - - :param str filename: the filename to save the state to. - """ - checkpoint = { - 'epoch': self.trained_epoch, - 'model_state': self.model.state_dict(), - 'optimizer_state': self.optimizer.state_dict(), - 'optimizer_class': self.optimizer.__class__, - 'history': self.history_loss, - 'input_points_dict': self.input_pts, - } - - # TODO save also architecture param? - # if isinstance(self.model, DeepFeedForward): - # checkpoint['model_class'] = self.model.__class__ - # checkpoint['model_structure'] = { - # } - torch.save(checkpoint, filename) - - def load_state(self, filename): - """ - Load the state of the model. + super().__init__(model=model, problem=problem, extra_features=extra_features) - :param str filename: the filename to load the state from. + # check consistency + check_consistency(optimizer, torch.optim.Optimizer, 'optimizer', subclass=True) + check_consistency(optimizer_kwargs, dict, 'optimizer_kwargs') + check_consistency(scheduler, lrs.LRScheduler, 'scheduler', subclass=True) + check_consistency(scheduler_kwargs, dict, 'scheduler_kwargs') + check_consistency(loss, (LossInterface, _Loss), 'loss', subclass=False) + + # assign variables + self._optimizer = optimizer(self.model.parameters(), **optimizer_kwargs) + self._scheduler = scheduler(self._optimizer, **scheduler_kwargs) + self._loss = loss + self._writer = Writer() + + + def forward(self, x): + """Forward pass implementation for the PINN + solver. + + :param torch.tensor x: Input data. + :return: PINN solution. + :rtype: torch.tensor + """ + # extract labels + x = x.extract(self.problem.input_variables) + # perform forward pass + output = self.model(x).as_subclass(LabelTensor) + # set the labels + output.labels = self.problem.output_variables + return output + + def configure_optimizers(self): + """Optimizer configuration for the PINN + solver. + + :return: The optimizers and the schedulers + :rtype: tuple(list, list) + """ + return [self._optimizer], [self._scheduler] + + def training_step(self, batch, batch_idx): + """PINN solver training step. + + :param batch: The batch element in the dataloader. + :type batch: tuple + :param batch_idx: The batch index. + :type batch_idx: int + :return: The sum of the loss functions. + :rtype: LabelTensor """ - checkpoint = torch.load(filename) - self.model.load_state_dict(checkpoint['model_state']) + condition_losses = [] - self.optimizer = checkpoint['optimizer_class'](self.model.parameters()) - self.optimizer.load_state_dict(checkpoint['optimizer_state']) + for condition_name, samples in batch.items(): - self.trained_epoch = checkpoint['epoch'] - self.history_loss = checkpoint['history'] + if condition_name not in self.problem.conditions: + raise RuntimeError('Something wrong happened.') - self.input_pts = checkpoint['input_points_dict'] + condition = self.problem.conditions[condition_name] - return self + # PINN loss: equation evaluated on location or input_points + if hasattr(condition, 'equation'): + target = condition.equation.residual(samples, self.forward(samples)) + loss = self._loss(torch.zeros_like(target), target) + # PINN loss: evaluate model(input_points) vs output_points + elif hasattr(condition, 'output_points'): + input_pts, output_pts = samples + loss = self._loss(self.forward(input_pts), output_pts) - def span_pts(self, *args, **kwargs): - """ - Generate a set of points to span the `Location` of all the conditions of - the problem. + condition_losses.append(loss * condition.data_weight) - >>> pinn.span_pts(n=10, mode='grid') - >>> pinn.span_pts(n=10, mode='grid', location=['bound1']) - >>> pinn.span_pts(n=10, mode='grid', variables=['x']) - """ - - if all(key in kwargs for key in ['n', 'mode']): - argument = {} - argument['n'] = kwargs['n'] - argument['mode'] = kwargs['mode'] - argument['variables'] = self.problem.input_variables - arguments = [argument] - elif any(key in kwargs for key in ['n', 'mode']) and args: - raise ValueError("Don't mix args and kwargs") - elif isinstance(args[0], int) and isinstance(args[1], str): - argument = {} - argument['n'] = int(args[0]) - argument['mode'] = args[1] - argument['variables'] = self.problem.input_variables - arguments = [argument] - elif all(isinstance(arg, dict) for arg in args): - arguments = args - else: - raise RuntimeError - - locations = kwargs.get('locations', 'all') - - if locations == 'all': - locations = [condition for condition in self.problem.conditions] - for location in locations: - condition = self.problem.conditions[location] - - samples = tuple(condition.location.sample( - argument['n'], - argument['mode'], - variables=argument['variables']) - for argument in arguments) - pts = merge_tensors(samples) - - # TODO - # pts = pts.double() - self.input_pts[location] = pts - - def _residual_loss(self, input_pts, equation): - """ - Compute the residual loss for a given condition. - - :param torch.Tensor pts: the points to evaluate the residual at. - :param Equation equation: the equation to evaluate the residual with. - """ - - input_pts = input_pts.to(dtype=self.dtype, device=self.device) - input_pts.requires_grad_(True) - input_pts.retain_grad() - - predicted = self.model(input_pts) - residuals = equation.residual(input_pts, predicted) - return self._compute_norm(residuals) - - def _data_loss(self, input_pts, output_pts): - """ - Compute the residual loss for a given condition. - - :param torch.Tensor pts: the points to evaluate the residual at. - :param Equation equation: the equation to evaluate the residual with. - """ - input_pts = input_pts.to(dtype=self.dtype, device=self.device) - output_pts = output_pts.to(dtype=self.dtype, device=self.device) - predicted = self.model(input_pts) - residuals = predicted - output_pts - return self._compute_norm(residuals) - - - # def closure(self): - # """ - # """ - # self.optimizer.zero_grad() - - # condition_losses = [] - # from torch.utils.data import DataLoader - # from .utils import MyDataset - # loader = DataLoader( - # MyDataset(self.input_pts), - # batch_size=self.batch_size, - # num_workers=1 - # ) - # for condition_name in self.problem.conditions: - # condition = self.problem.conditions[condition_name] - - # batch_losses = [] - # for batch in data_loader[condition_name]: - - # if hasattr(condition, 'equation'): - # loss = self._residual_loss( - # batch[condition_name], condition.equation) - # elif hasattr(condition, 'output_points'): - # loss = self._data_loss( - # batch[condition_name], condition.output_points) - - # batch_losses.append(loss * condition.data_weight) - - # condition_losses.append(sum(batch_losses)) - - # loss = sum(condition_losses) - # loss.backward() - # return loss - - def closure(self): - """ - """ - self.optimizer.zero_grad() - - losses = [] - for i, batch in enumerate(self.loader): - - condition_losses = [] - - for condition_name, samples in batch.items(): - - if condition_name not in self.problem.conditions: - raise RuntimeError('Something wrong happened.') - - if samples is None or samples.nelement() == 0: - continue - - condition = self.problem.conditions[condition_name] - - if hasattr(condition, 'equation'): - loss = self._residual_loss(samples, condition.equation) - elif hasattr(condition, 'output_points'): - loss = self._data_loss(samples, condition.output_points) - - condition_losses.append(loss * condition.data_weight) - - losses.append(sum(condition_losses)) - - loss = sum(losses) - loss.backward() - return losses[0] - - def train(self, stop=100): - - self.model.train() - - ############################################################ - ## TODO: move to problem class - for condition in list(set(self.problem.conditions.keys()) - set(self.input_pts.keys())): - self.input_pts[condition] = self.problem.conditions[condition].input_points - - mydata = self.input_pts - - self.loader = DummyLoader(mydata) - - while True: - - loss = self.optimizer.step(closure=self.closure) - - self.writer.write_loss_in_loop(self, loss) - - #self._lr_scheduler.step() - - if isinstance(stop, int): - if self.trained_epoch == stop: - break - elif isinstance(stop, float): - if loss.item() < stop: - break - - self.trained_epoch += 1 - - self.model.eval() \ No newline at end of file + # TODO Fix the bug, tot_loss is a label tensor without labels + # we need to pass it as a torch tensor to make everything work + total_loss = sum(condition_losses) + return total_loss \ No newline at end of file diff --git a/pina/plotter.py b/pina/plotter.py index b509e2d..d92e780 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -11,11 +11,11 @@ class Plotter: Implementation of a plotter class, for easy visualizations. """ - def plot_samples(self, pinn, variables=None): + def plot_samples(self, solver, variables=None): """ - Plot a sample of solution. + Plot the training grid samples. - :param PINN pinn: the PINN object. + :param SolverInterface solver: the SolverInterface object. :param list(str) variables: variables to plot. If None, all variables are plotted. If 'spatial', only spatial variables are plotted. If 'temporal', only temporal variables are plotted. Defaults to None. @@ -26,15 +26,15 @@ class Plotter: :Example: >>> plotter = Plotter() - >>> plotter.plot_samples(pinn=pinn, variables='spatial') + >>> plotter.plot_samples(solver=solver, variables='spatial') """ if variables is None: - variables = pinn.problem.domain.variables + variables = solver.problem.domain.variables elif variables == 'spatial': - variables = pinn.problem.spatial_domain.variables + variables = solver.problem.spatial_domain.variables elif variables == 'temporal': - variables = pinn.problem.temporal_domain.variables + variables = solver.problem.temporal_domain.variables if len(variables) not in [1, 2, 3]: raise ValueError @@ -42,8 +42,8 @@ class Plotter: fig = plt.figure() proj = '3d' if len(variables) == 3 else None ax = fig.add_subplot(projection=proj) - for location in pinn.input_pts: - coords = pinn.input_pts[location].extract(variables).T.detach() + for location in solver.problem.input_pts: + coords = solver.problem.input_pts[location].extract(variables).T.detach() if coords.shape[0] == 1: # 1D samples ax.plot(coords[0], torch.zeros(coords[0].shape), '.', label=location) @@ -69,7 +69,7 @@ class Plotter: :param pts: Points to plot the solution. :type pts: torch.Tensor - :param pred: PINN solution evaluated at 'pts'. + :param pred: SolverInterface solution evaluated at 'pts'. :type pred: torch.Tensor :param method: not used, kept for code compatibility :type method: None @@ -95,7 +95,7 @@ class Plotter: :param pts: Points to plot the solution. :type pts: torch.Tensor - :param pred: PINN solution evaluated at 'pts'. + :param pred: SolverInterface solution evaluated at 'pts'. :type pred: torch.Tensor :param method: matplotlib method to plot 2-dimensional data, see https://matplotlib.org/stable/api/axes_api.html for @@ -129,12 +129,12 @@ class Plotter: *grids, pred_output.cpu().detach(), **kwargs) fig.colorbar(cb, ax=ax) - def plot(self, pinn, components=None, fixed_variables={}, method='contourf', + def plot(self, solver, components=None, fixed_variables={}, method='contourf', res=256, filename=None, **kwargs): """ - Plot sample of PINN output. + Plot sample of SolverInterface output. - :param PINN pinn: the PINN object. + :param SolverInterface solver: the SolverInterface object. :param list(str) components: the output variable to plot. If None, all the output variables of the problem are selected. Default value is None. @@ -150,12 +150,12 @@ class Plotter: is shown using the setted matplotlib frontend. Default is None. """ if components is None: - components = [pinn.problem.output_variables] + components = [solver.problem.output_variables] v = [ - var for var in pinn.problem.input_variables + var for var in solver.problem.input_variables if var not in fixed_variables.keys() ] - pts = pinn.problem.domain.sample(res, 'grid', variables=v) + pts = solver.problem.domain.sample(res, 'grid', variables=v) fixed_pts = torch.ones(pts.shape[0], len(fixed_variables)) fixed_pts *= torch.tensor(list(fixed_variables.values())) @@ -163,15 +163,15 @@ class Plotter: fixed_pts.labels = list(fixed_variables.keys()) pts = pts.append(fixed_pts) - pts = pts.to(device=pinn.device) + pts = pts.to(device=solver.device) - predicted_output = pinn.model(pts) + predicted_output = solver.forward(pts) if isinstance(components, str): predicted_output = predicted_output.extract(components) elif callable(components): predicted_output = components(predicted_output) - truth_solution = getattr(pinn.problem, 'truth_solution', None) + truth_solution = getattr(solver.problem, 'truth_solution', None) if len(v) == 1: self._1d_plot(pts, predicted_output, method, truth_solution, **kwargs) @@ -186,37 +186,25 @@ class Plotter: else: plt.show() - def plot_loss(self, pinn, label=None, log_scale=True, filename=None): - """ - Plot the loss function values during traininig. + # TODO loss + # def plot_loss(self, solver, label=None, log_scale=True): + # """ + # Plot the loss function values during traininig. - :param PINN pinn: the PINN object. - :param str label: the label to use in the legend, defaults to None. - :param bool log_scale: If True, the y axis is in log scale. Default is - True. - :param str filename: the file name to save the plot. If None, the plot - is not saved. Default is None. - """ + # :param SolverInterface solver: the SolverInterface object. + # :param str label: the label to use in the legend, defaults to None. + # :param bool log_scale: If True, the y axis is in log scale. Default is + # True. + # """ - if not label: - label = str(pinn) + # if not label: + # label = str(solver) - epochs = list(pinn.history_loss.keys()) - loss = np.array(list(pinn.history_loss.values())) + # epochs = list(solver.history_loss.keys()) + # loss = np.array(list(solver.history_loss.values())) + # if loss.ndim != 1: + # loss = loss[:, 0] - # if multiple outputs, sum the loss - if loss.ndim != 1: - loss = np.sum(loss, axis=1) - - # plot loss - plt.plot(epochs, loss, label=label) - plt.legend() - if log_scale: - plt.yscale('log') - plt.title('Loss function') - plt.xlabel('Epochs') - plt.ylabel('Loss') - - # save plot - if filename: - plt.savefig(filename) + # plt.plot(epochs, loss, label=label) + # if log_scale: + # plt.yscale('log') diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index fc81e22..8f32e90 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -1,5 +1,6 @@ """ Module for AbstractProblem class """ from abc import ABCMeta, abstractmethod +from ..utils import merge_tensors class AbstractProblem(metaclass=ABCMeta): @@ -11,6 +12,19 @@ class AbstractProblem(metaclass=ABCMeta): the output variables, the condition(s), and the domain(s) where the conditions are applied. """ + + def __init__(self): + + # variable storing all points + self.input_pts = {} + + # varible to check if sampling is done. If no location + # element is presented in Condition this variable is set to true + self._have_sampled_points = {} + + # put in self.input_pts all the points that we don't need to sample + self._span_condition_points() + @property def input_variables(self): """ @@ -80,3 +94,91 @@ class AbstractProblem(metaclass=ABCMeta): The conditions of the problem. """ pass + + def _span_condition_points(self): + """ + Simple function to get the condition points + """ + for condition_name in self.conditions: + condition = self.conditions[condition_name] + if hasattr(condition, 'equation') and hasattr(condition, 'input_points'): + samples = condition.input_points + elif hasattr(condition, 'output_points') and hasattr(condition, 'input_points'): + samples = (condition.input_points, condition.output_points) + # skip if we need to sample + elif hasattr(condition, 'location'): + self._have_sampled_points[condition_name] = False + continue + self.input_pts[condition_name] = samples + + def discretise_domain(self, *args, **kwargs): + """ + Generate a set of points to span the `Location` of all the conditions of + the problem. + + >>> pinn.span_pts(n=10, mode='grid') + >>> pinn.span_pts(n=10, mode='grid', location=['bound1']) + >>> pinn.span_pts(n=10, mode='grid', variables=['x']) + """ + if all(key in kwargs for key in ['n', 'mode']): + argument = {} + argument['n'] = kwargs['n'] + argument['mode'] = kwargs['mode'] + argument['variables'] = self.input_variables + arguments = [argument] + elif any(key in kwargs for key in ['n', 'mode']) and args: + raise ValueError("Don't mix args and kwargs") + elif isinstance(args[0], int) and isinstance(args[1], str): + argument = {} + argument['n'] = int(args[0]) + argument['mode'] = args[1] + argument['variables'] = self.input_variables + arguments = [argument] + elif all(isinstance(arg, dict) for arg in args): + arguments = args + else: + raise RuntimeError + + locations = kwargs.get('locations', 'all') + + if locations == 'all': + locations = [condition for condition in self.conditions] + for location in locations: + condition = self.conditions[location] + + samples = tuple(condition.location.sample( + argument['n'], + argument['mode'], + variables=argument['variables']) + for argument in arguments) + pts = merge_tensors(samples) + self.input_pts[location] = pts + # setting the grad + self.input_pts[location].requires_grad_(True) + self.input_pts[location].retain_grad() + # the condition is sampled + self._have_sampled_points[location] = True + + @property + def have_sampled_points(self): + """ + Check if all points for + ``'Location'`` are sampled. + """ + return all(self._have_sampled_points.values()) + + @property + def not_sampled_points(self): + """Check which points are + not sampled. + """ + # variables which are not sampled + not_sampled = None + if self.have_sampled_points is False: + # check which one are not sampled: + not_sampled = [] + for condition_name, is_sample in self._have_sampled_points.items(): + if not is_sample: + not_sampled.append(condition_name) + return not_sampled + diff --git a/pina/solver.py b/pina/solver.py new file mode 100644 index 0000000..9625603 --- /dev/null +++ b/pina/solver.py @@ -0,0 +1,65 @@ +""" Solver module. """ + +from abc import ABCMeta, abstractmethod +from .model.network import Network +import lightning.pytorch as pl +from .utils import check_consistency +from .problem import AbstractProblem + +class SolverInterface(pl.LightningModule, metaclass=ABCMeta): + """ Solver base class. """ + def __init__(self, model, problem, extra_features=None): + """ + :param model: A torch neural network model instance. + :type model: torch.nn.Module + :param problem: A problem definition instance. + :type problem: AbstractProblem + :param list(torch.nn.Module) extra_features: the additional input + features to use as augmented input. + """ + super().__init__() + + # check inheritance for pina problem + check_consistency(problem, AbstractProblem, 'pina problem') + + # assigning class variables (check consistency inside Network class) + self._pina_model = Network(model=model, extra_features=extra_features) + self._problem = problem + + @abstractmethod + def forward(self): + pass + + @abstractmethod + def training_step(self): + pass + + @abstractmethod + def configure_optimizers(self): + pass + + @property + def model(self): + """ + The torch model.""" + return self._pina_model + + @property + def problem(self): + """ + The problem formulation.""" + return self._problem + + # @model.setter + # def model(self, new_model): + # """ + # Set the torch.""" + # check_consistency(new_model, nn.Module, 'torch model') + # self._model= new_model + + # @problem.setter + # def problem(self, problem): + # """ + # Set the problem formulation.""" + # check_consistency(problem, AbstractProblem, 'pina problem') + # self._problem = problem \ No newline at end of file diff --git a/pina/trainer.py b/pina/trainer.py new file mode 100644 index 0000000..997f14e --- /dev/null +++ b/pina/trainer.py @@ -0,0 +1,31 @@ +""" Solver module. """ + +import lightning.pytorch as pl +from .utils import check_consistency +from .dataset import DummyLoader +from .solver import SolverInterface + +class Trainer(pl.Trainer): + + def __init__(self, solver, kwargs={}): + super().__init__(**kwargs) + + # check inheritance consistency for solver + check_consistency(solver, SolverInterface, 'Solver model') + self._model = solver + + # create dataloader + if solver.problem.have_sampled_points is False: + raise RuntimeError(f'Input points in {solver.problem.not_sampled_points} ' + 'training are None. Please ' + 'sample points in your problem by calling ' + 'discretise_domain function before train ' + 'in the provided locations.') + + # TODO: make a better dataloader for train + self._loader = DummyLoader(solver.problem.input_pts) + + + def train(self): # TODO add kwargs and lightining capabilities + return super().fit(self._model, self._loader) + diff --git a/pina/utils.py b/pina/utils.py index 3f5d7ef..798a4ad 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -10,6 +10,29 @@ from .label_tensor import LabelTensor import torch +def check_consistency(object, object_instance, object_name, subclass=False): + """Helper function to check object inheritance consistency. + Given a specific ``'object'`` we check if the object is + instance of a specific ``'object_instance'``, or in case + ``'subclass=True'`` we check if the object is subclass + if the ``'object_instance'``. + + :param Object object: The object to check the inheritance + :param Object object_instance: The parent class from where the object + is expected to inherit + :param str object_name: The name of the object + :param bool subclass: Check if is a subclass and not instance + :raises ValueError: If the object does not inherit from the + specified class + """ + if not subclass: + if not isinstance(object, object_instance): + raise ValueError(f"{object_name} must be {object_instance}") + else: + if not issubclass(object, object_instance): + raise ValueError(f"{object_name} must be {object_instance}") + + def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check """ Return the number of parameters of a given `model`. @@ -189,8 +212,7 @@ class LabelTensorDataset(Dataset): class LabelTensorDataLoader(DataLoader): def collate_fn(self, data): - print(data) - gggggggggg + pass # return dict(zip(self.pinn.input_pts.keys(), dataloaders)) # class SampleDataset(torch.utils.data.Dataset): @@ -239,5 +261,4 @@ class LabelTensorDataset(Dataset): class LabelTensorDataLoader(DataLoader): def collate_fn(self, data): - print(data) - gggggggggg \ No newline at end of file + pass \ No newline at end of file diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100644 index 0000000..24db012 --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,49 @@ +import torch +import pytest + +from pina.loss import * + +input = torch.tensor([[3.], [1.], [-8.]]) +target = torch.tensor([[6.], [4.], [2.]]) +available_reductions = ['str', 'mean', 'none'] + + +def test_LpLoss_constructor(): + # test reduction + for reduction in available_reductions: + LpLoss(reduction=reduction) + # test p + for p in [float('inf'), -float('inf'), 1, 10, -8]: + LpLoss(p=p) + +def test_LpLoss_forward(): + # l2 loss + loss = LpLoss(p=2, reduction='mean') + l2_loss = torch.mean(torch.sqrt((input-target).pow(2))) + assert loss(input, target) == l2_loss + # l1 loss + loss = LpLoss(p=1, reduction='sum') + l1_loss = torch.sum(torch.abs(input-target)) + assert loss(input, target) == l1_loss + +def test_LpRelativeLoss_constructor(): + # test reduction + for reduction in available_reductions: + LpLoss(reduction=reduction, relative=True) + # test p + for p in [float('inf'), -float('inf'), 1, 10, -8]: + LpLoss(p=p,relative=True) + +def test_LpRelativeLoss_forward(): + # l2 relative loss + loss = LpLoss(p=2, reduction='mean',relative=True) + l2_loss = torch.sqrt((input-target).pow(2))/torch.sqrt(input.pow(2)) + assert loss(input, target) == torch.mean(l2_loss) + # l1 relative loss + loss = LpLoss(p=1, reduction='sum',relative=True) + l1_loss = torch.abs(input-target)/torch.abs(input) + assert loss(input, target) == torch.sum(l1_loss) + + + + diff --git a/tests/test_model/test_network.py b/tests/test_model/test_network.py deleted file mode 100644 index e5223e0..0000000 --- a/tests/test_model/test_network.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -import torch.nn as nn -import pytest -from pina.model import Network, FeedForward -from pina import LabelTensor - - -class myFeature(torch.nn.Module): - """ - Feature: sin(x) - """ - - def __init__(self): - super(myFeature, self).__init__() - - def forward(self, x): - t = (torch.sin(x.extract(['x'])*torch.pi) * - torch.sin(x.extract(['y'])*torch.pi)) - return LabelTensor(t, ['sin(x)sin(y)']) - - -input_variables = ['x', 'y'] -output_variables = ['u'] -data = torch.rand((20, 2)) -input_ = LabelTensor(data, input_variables) - - -def test_constructor(): - net = FeedForward(2, 1) - pina_net = Network(model=net, input_variables=input_variables, - output_variables=output_variables) - - -def test_forward(): - net = FeedForward(2, 1) - pina_net = Network(model=net, input_variables=input_variables, - output_variables=output_variables) - output_ = pina_net(input_) - assert output_.labels == output_variables - - -def test_constructor_extrafeat(): - net = FeedForward(3, 1) - feat = [myFeature()] - pina_net = Network(model=net, input_variables=input_variables, - output_variables=output_variables, extra_features=feat) - - -def test_forward_extrafeat(): - net = FeedForward(3, 1) - feat = [myFeature()] - pina_net = Network(model=net, input_variables=input_variables, - output_variables=output_variables, extra_features=feat) - output_ = pina_net(input_) - assert output_.labels == output_variables diff --git a/tests/test_pinn.py b/tests/test_pinn.py index 6e6b28b..e04aa58 100644 --- a/tests/test_pinn.py +++ b/tests/test_pinn.py @@ -1,17 +1,18 @@ import torch import pytest -from pina import LabelTensor, Condition, CartesianDomain, PINN from pina.problem import SpatialProblem -from pina.model import FeedForward from pina.operators import nabla +from pina.geometry import CartesianDomain +from pina import Condition, LabelTensor, PINN +from pina.trainer import Trainer +from pina.model import FeedForward from pina.equation.equation import Equation from pina.equation.equation_factory import FixedValue +from pina.plotter import Plotter +from pina.loss import LpLoss -in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y']) -out_ = LabelTensor(torch.tensor([[0.]]), ['u']) - def laplace_equation(input_, output_): force_term = (torch.sin(input_.extract(['x'])*torch.pi) * torch.sin(input_.extract(['y'])*torch.pi)) @@ -19,6 +20,8 @@ def laplace_equation(input_, output_): return nabla_u - force_term my_laplace = Equation(laplace_equation) +in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y']) +out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u']) class Poisson(SpatialProblem): output_variables = ['u'] @@ -68,75 +71,40 @@ class myFeature(torch.nn.Module): return LabelTensor(t, ['sin(x)sin(y)']) -problem = Poisson() -model = FeedForward(len(problem.input_variables),len(problem.output_variables)) -model_extra_feat = FeedForward(len(problem.input_variables) + 1,len(problem.output_variables)) +# make the problem +poisson_problem = Poisson() +model = FeedForward(len(poisson_problem.input_variables),len(poisson_problem.output_variables)) +model_extra_feats = FeedForward(len(poisson_problem.input_variables)+1,len(poisson_problem.output_variables)) +extra_feats = [myFeature()] def test_constructor(): - PINN(problem, model) + PINN(problem = poisson_problem, model=model, extra_features=None) def test_constructor_extra_feats(): - PINN(problem, model_extra_feat, [myFeature()]) - - -def test_span_pts(): - pinn = PINN(problem, model) - n = 10 - boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] - pinn.span_pts(n, 'grid', locations=boundaries) - for b in boundaries: - assert pinn.input_pts[b].shape[0] == n - pinn.span_pts(n, 'random', locations=boundaries) - for b in boundaries: - assert pinn.input_pts[b].shape[0] == n - - pinn.span_pts(n, 'grid', locations=['D']) - assert pinn.input_pts['D'].shape[0] == n**2 - pinn.span_pts(n, 'random', locations=['D']) - assert pinn.input_pts['D'].shape[0] == n - - pinn.span_pts(n, 'latin', locations=['D']) - assert pinn.input_pts['D'].shape[0] == n - - pinn.span_pts(n, 'lh', locations=['D']) - assert pinn.input_pts['D'].shape[0] == n - - -def test_sampling_all_args(): - pinn = PINN(problem, model) - n = 10 - pinn.span_pts(n, 'grid', locations=['D']) - - -def test_sampling_all_kwargs(): - pinn = PINN(problem, model) - n = 10 - pinn.span_pts(n=n, mode='latin', locations=['D']) - - -def test_sampling_dict(): - pinn = PINN(problem, model) - n = 10 - pinn.span_pts( - {'variables': ['x', 'y'], 'mode': 'grid', 'n': n}, locations=['D']) - - -def test_sampling_mixed_args_kwargs(): - pinn = PINN(problem, model) - n = 10 - with pytest.raises(ValueError): - pinn.span_pts(n, mode='latin', locations=['D']) - + model_extra_feats = FeedForward(len(poisson_problem.input_variables)+1,len(poisson_problem.output_variables)) + PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats) def test_train(): - pinn = PINN(problem, model) + poisson_problem = Poisson() boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] n = 10 - pinn.span_pts(n, 'grid', locations=boundaries) - pinn.span_pts(n, 'grid', locations=['D']) - pinn.train(5) + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + poisson_problem.discretise_domain(n, 'grid', locations=['D']) + pinn = PINN(problem = poisson_problem, model=model, extra_features=None, loss=LpLoss()) + trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5}) + trainer.train() + +def test_train_extra_feats(): + poisson_problem = Poisson() + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + poisson_problem.discretise_domain(n, 'grid', locations=['D']) + pinn = PINN(problem = poisson_problem, model=model_extra_feats, extra_features=extra_feats) + trainer = Trainer(solver=pinn, kwargs={'max_epochs' : 5}) + trainer.train() """ def test_train_2(): @@ -146,8 +114,8 @@ def test_train_2(): param = [0, 3] for i, truth_key in zip(param, expected_keys): pinn = PINN(problem, model) - pinn.span_pts(n, 'grid', locations=boundaries) - pinn.span_pts(n, 'grid', locations=['D']) + pinn.discretise_domain(n, 'grid', locations=boundaries) + pinn.discretise_domain(n, 'grid', locations=['D']) pinn.train(50, save_loss=i) assert list(pinn.history_loss.keys()) == truth_key @@ -156,8 +124,8 @@ def test_train_extra_feats(): pinn = PINN(problem, model_extra_feat, [myFeature()]) boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] n = 10 - pinn.span_pts(n, 'grid', locations=boundaries) - pinn.span_pts(n, 'grid', locations=['D']) + pinn.discretise_domain(n, 'grid', locations=boundaries) + pinn.discretise_domain(n, 'grid', locations=['D']) pinn.train(5) @@ -168,8 +136,8 @@ def test_train_2_extra_feats(): param = [0, 3] for i, truth_key in zip(param, expected_keys): pinn = PINN(problem, model_extra_feat, [myFeature()]) - pinn.span_pts(n, 'grid', locations=boundaries) - pinn.span_pts(n, 'grid', locations=['D']) + pinn.discretise_domain(n, 'grid', locations=boundaries) + pinn.discretise_domain(n, 'grid', locations=['D']) pinn.train(50, save_loss=i) assert list(pinn.history_loss.keys()) == truth_key @@ -181,8 +149,8 @@ def test_train_with_optimizer_kwargs(): param = [0, 3] for i, truth_key in zip(param, expected_keys): pinn = PINN(problem, model, optimizer_kwargs={'lr' : 0.3}) - pinn.span_pts(n, 'grid', locations=boundaries) - pinn.span_pts(n, 'grid', locations=['D']) + pinn.discretise_domain(n, 'grid', locations=boundaries) + pinn.discretise_domain(n, 'grid', locations=['D']) pinn.train(50, save_loss=i) assert list(pinn.history_loss.keys()) == truth_key @@ -199,8 +167,8 @@ def test_train_with_lr_scheduler(): lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR, lr_scheduler_kwargs={'base_lr' : 0.1, 'max_lr' : 0.3, 'cycle_momentum': False} ) - pinn.span_pts(n, 'grid', locations=boundaries) - pinn.span_pts(n, 'grid', locations=['D']) + pinn.discretise_domain(n, 'grid', locations=boundaries) + pinn.discretise_domain(n, 'grid', locations=['D']) pinn.train(50, save_loss=i) assert list(pinn.history_loss.keys()) == truth_key @@ -209,8 +177,8 @@ def test_train_with_lr_scheduler(): # pinn = PINN(problem, model, batch_size=6) # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] # n = 10 -# pinn.span_pts(n, 'grid', locations=boundaries) -# pinn.span_pts(n, 'grid', locations=['D']) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) # pinn.train(5) @@ -221,8 +189,8 @@ def test_train_with_lr_scheduler(): # param = [0, 3] # for i, truth_key in zip(param, expected_keys): # pinn = PINN(problem, model, batch_size=6) -# pinn.span_pts(n, 'grid', locations=boundaries) -# pinn.span_pts(n, 'grid', locations=['D']) +# pinn.discretise_domain(n, 'grid', locations=boundaries) +# pinn.discretise_domain(n, 'grid', locations=['D']) # pinn.train(50, save_loss=i) # assert list(pinn.history_loss.keys()) == truth_key @@ -233,15 +201,15 @@ if torch.cuda.is_available(): # pinn = PINN(problem, model, batch_size=20, device='cuda') # boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] # n = 100 - # pinn.span_pts(n, 'grid', locations=boundaries) - # pinn.span_pts(n, 'grid', locations=['D']) + # pinn.discretise_domain(n, 'grid', locations=boundaries) + # pinn.discretise_domain(n, 'grid', locations=['D']) # pinn.train(5) def test_gpu_train_nobatch(): pinn = PINN(problem, model, batch_size=None, device='cuda') boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] n = 100 - pinn.span_pts(n, 'grid', locations=boundaries) - pinn.span_pts(n, 'grid', locations=['D']) + pinn.discretise_domain(n, 'grid', locations=boundaries) + pinn.discretise_domain(n, 'grid', locations=['D']) pinn.train(5) """ \ No newline at end of file diff --git a/tests/test_problem.py b/tests/test_problem.py new file mode 100644 index 0000000..d991c1b --- /dev/null +++ b/tests/test_problem.py @@ -0,0 +1,97 @@ +import torch +import pytest + +from pina.problem import SpatialProblem +from pina.operators import nabla +from pina import LabelTensor, Condition +from pina.geometry import CartesianDomain +from pina.equation.equation import Equation +from pina.equation.equation_factory import FixedValue + + +def laplace_equation(input_, output_): + force_term = (torch.sin(input_.extract(['x'])*torch.pi) * + torch.sin(input_.extract(['y'])*torch.pi)) + nabla_u = nabla(output_.extract(['u']), input_) + return nabla_u - force_term + +my_laplace = Equation(laplace_equation) +in_ = LabelTensor(torch.tensor([[0., 1.]], requires_grad=True), ['x', 'y']) +out_ = LabelTensor(torch.tensor([[0.]], requires_grad=True), ['u']) + +class Poisson(SpatialProblem): + output_variables = ['u'] + spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) + + conditions = { + 'gamma1': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 1}), + equation=FixedValue(0.0)), + 'gamma2': Condition( + location=CartesianDomain({'x': [0, 1], 'y': 0}), + equation=FixedValue(0.0)), + 'gamma3': Condition( + location=CartesianDomain({'x': 1, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'gamma4': Condition( + location=CartesianDomain({'x': 0, 'y': [0, 1]}), + equation=FixedValue(0.0)), + 'D': Condition( + location=CartesianDomain({'x': [0, 1], 'y': [0, 1]}), + equation=my_laplace), + 'data': Condition( + input_points=in_, + output_points=out_) + } + + def poisson_sol(self, pts): + return -( + torch.sin(pts.extract(['x'])*torch.pi) * + torch.sin(pts.extract(['y'])*torch.pi) + )/(2*torch.pi**2) + + truth_solution = poisson_sol + + +# make the problem +poisson_problem = Poisson() + + +def test_discretise_domain(): + n = 10 + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] + poisson_problem.discretise_domain(n, 'grid', locations=boundaries) + for b in boundaries: + assert poisson_problem.input_pts[b].shape[0] == n + poisson_problem.discretise_domain(n, 'random', locations=boundaries) + for b in boundaries: + assert poisson_problem.input_pts[b].shape[0] == n + + poisson_problem.discretise_domain(n, 'grid', locations=['D']) + assert poisson_problem.input_pts['D'].shape[0] == n**2 + poisson_problem.discretise_domain(n, 'random', locations=['D']) + assert poisson_problem.input_pts['D'].shape[0] == n + + poisson_problem.discretise_domain(n, 'latin', locations=['D']) + assert poisson_problem.input_pts['D'].shape[0] == n + + poisson_problem.discretise_domain(n, 'lh', locations=['D']) + assert poisson_problem.input_pts['D'].shape[0] == n + +def test_sampling_all_args(): + n = 10 + poisson_problem.discretise_domain(n, 'grid', locations=['D']) + +def test_sampling_all_kwargs(): + n = 10 + poisson_problem.discretise_domain(n=n, mode='latin', locations=['D']) + +def test_sampling_dict(): + n = 10 + poisson_problem.discretise_domain( + {'variables': ['x', 'y'], 'mode': 'grid', 'n': n}, locations=['D']) + +def test_sampling_mixed_args_kwargs(): + n = 10 + with pytest.raises(ValueError): + poisson_problem.discretise_domain(n, mode='latin', locations=['D']) \ No newline at end of file