16
pina/pinn.py
16
pina/pinn.py
@@ -26,21 +26,22 @@ class PINN(object):
|
||||
device='cpu',
|
||||
error_norm='mse'):
|
||||
'''
|
||||
:param Problem problem: the formualation of the problem.
|
||||
:param AbstractProblem problem: the formualation of the problem.
|
||||
:param torch.nn.Module model: the neural network model to use.
|
||||
:param torch.optim optimizer: the neural network optimizer to use;
|
||||
default is `torch.optim.Adam`.
|
||||
:param torch.optim.Optimizer optimizer: the neural network optimizer to
|
||||
use; default is `torch.optim.Adam`.
|
||||
:param dict optimizer_kwargs: Optimizer constructor keyword args.
|
||||
:param float lr: the learning rate; default is 0.001.
|
||||
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler_type: Learning rate scheduler.
|
||||
:param torch.optim.LRScheduler lr_scheduler_type: Learning
|
||||
rate scheduler.
|
||||
:param dict lr_scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
:param float regularizer: the coefficient for L2 regularizer term.
|
||||
:param type dtype: the data type to use for the model. Valid option are
|
||||
`torch.float32` and `torch.float64` (`torch.float16` only on GPU);
|
||||
default is `torch.float64`.
|
||||
:param string device: the device used for training; default 'cpu'
|
||||
:param str device: the device used for training; default 'cpu'
|
||||
option include 'cuda' if cuda is available.
|
||||
:param string/int error_norm: the loss function used as minimizer,
|
||||
:param (str, int) error_norm: the loss function used as minimizer,
|
||||
default mean square error 'mse'. If string options include mean
|
||||
error 'me' and mean square error 'mse'. If int, the p-norm is
|
||||
calculated where p is specifined by the int input.
|
||||
@@ -161,6 +162,9 @@ class PINN(object):
|
||||
|
||||
def span_pts(self, *args, **kwargs):
|
||||
"""
|
||||
Generate a set of points to span the `Location` of all the conditions of
|
||||
the problem.
|
||||
|
||||
>>> pinn.span_pts(n=10, mode='grid')
|
||||
>>> pinn.span_pts(n=10, mode='grid', location=['bound1'])
|
||||
>>> pinn.span_pts(n=10, mode='grid', variables=['x'])
|
||||
|
||||
Reference in New Issue
Block a user