* clean `condition` module
* add docs
This commit is contained in:
Nicola Demo
2023-04-18 15:00:26 +02:00
committed by GitHub
parent 736c78fd64
commit 2ca08b5236
18 changed files with 198 additions and 158 deletions

View File

@@ -26,21 +26,22 @@ class PINN(object):
device='cpu',
error_norm='mse'):
'''
:param Problem problem: the formualation of the problem.
:param AbstractProblem problem: the formualation of the problem.
:param torch.nn.Module model: the neural network model to use.
:param torch.optim optimizer: the neural network optimizer to use;
default is `torch.optim.Adam`.
:param torch.optim.Optimizer optimizer: the neural network optimizer to
use; default is `torch.optim.Adam`.
:param dict optimizer_kwargs: Optimizer constructor keyword args.
:param float lr: the learning rate; default is 0.001.
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler_type: Learning rate scheduler.
:param torch.optim.LRScheduler lr_scheduler_type: Learning
rate scheduler.
:param dict lr_scheduler_kwargs: LR scheduler constructor keyword args.
:param float regularizer: the coefficient for L2 regularizer term.
:param type dtype: the data type to use for the model. Valid option are
`torch.float32` and `torch.float64` (`torch.float16` only on GPU);
default is `torch.float64`.
:param string device: the device used for training; default 'cpu'
:param str device: the device used for training; default 'cpu'
option include 'cuda' if cuda is available.
:param string/int error_norm: the loss function used as minimizer,
:param (str, int) error_norm: the loss function used as minimizer,
default mean square error 'mse'. If string options include mean
error 'me' and mean square error 'mse'. If int, the p-norm is
calculated where p is specifined by the int input.
@@ -161,6 +162,9 @@ class PINN(object):
def span_pts(self, *args, **kwargs):
"""
Generate a set of points to span the `Location` of all the conditions of
the problem.
>>> pinn.span_pts(n=10, mode='grid')
>>> pinn.span_pts(n=10, mode='grid', location=['bound1'])
>>> pinn.span_pts(n=10, mode='grid', variables=['x'])