fix doc model part 2
This commit is contained in:
@@ -7,30 +7,27 @@ from .integral import Integral
|
||||
|
||||
|
||||
class ContinuousConvBlock(BaseContinuousConv):
|
||||
"""
|
||||
Implementation of Continuous Convolutional operator.
|
||||
|
||||
The algorithm expects input to be in the form:
|
||||
:math:`[B, N_{in}, N, D]`
|
||||
where :math:`B` is the batch_size, :math:`N_{in}` is the number of input
|
||||
fields, :math:`N` the number of points in the mesh, :math:`D` the dimension
|
||||
of the problem. In particular:
|
||||
r"""
|
||||
Continuous Convolutional block.
|
||||
|
||||
The class expects the input to be in the form:
|
||||
:math:`[B \times N_{in} \times N \times D]`, where :math:`B` is the
|
||||
batch_size, :math:`N_{in}` is the number of input fields, :math:`N`
|
||||
the number of points in the mesh, :math:`D` the dimension of the problem.
|
||||
In particular:
|
||||
* :math:`D` is the number of spatial variables + 1. The last column must
|
||||
contain the field value. For example for 2D problems :math:`D=3` and
|
||||
the tensor will be something like ``[first coordinate, second
|
||||
coordinate, field value]``.
|
||||
* :math:`N_{in}` represents the number of vectorial function presented.
|
||||
For example a vectorial function :math:`f = [f_1, f_2]` will have
|
||||
contain the field value.
|
||||
* :math:`N_{in}` represents the number of function components.
|
||||
For instance, a vectorial function :math:`f = [f_1, f_2]` has
|
||||
:math:`N_{in}=2`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Coscia, D., Meneghetti, L., Demo, N. et al.
|
||||
*A continuous convolutional trainable filter for modelling
|
||||
unstructured data*. Comput Mech 72, 253–265 (2023).
|
||||
**Original reference**:
|
||||
Coscia, D., Meneghetti, L., Demo, N. et al.
|
||||
*A continuous convolutional trainable filter for modelling unstructured
|
||||
data*. Comput Mech 72, 253-265 (2023).
|
||||
DOI `<https://doi.org/10.1007/s00466-023-02291-1>`_
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -44,53 +41,48 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
no_overlap=False,
|
||||
):
|
||||
"""
|
||||
:param input_numb_field: Number of fields :math:`N_{in}` in the input.
|
||||
:type input_numb_field: int
|
||||
:param output_numb_field: Number of fields :math:`N_{out}` in the
|
||||
output.
|
||||
:type output_numb_field: int
|
||||
:param filter_dim: Dimension of the filter.
|
||||
:type filter_dim: tuple(int) | list(int)
|
||||
:param stride: Stride for the filter.
|
||||
:type stride: dict
|
||||
:param model: Neural network for inner parametrization,
|
||||
defaults to ``None``. If None, a default multilayer perceptron
|
||||
of width three and size twenty with ReLU activation is used.
|
||||
:type model: torch.nn.Module
|
||||
:param optimize: Flag for performing optimization on the continuous
|
||||
filter, defaults to False. The flag `optimize=True` should be
|
||||
used only when the scatter datapoints are fixed through the
|
||||
training. If torch model is in ``.eval()`` mode, the flag is
|
||||
automatically set to False always.
|
||||
:type optimize: bool
|
||||
:param no_overlap: Flag for performing optimization on the transpose
|
||||
continuous filter, defaults to False. The flag set to `True` should
|
||||
be used only when the filter positions do not overlap for different
|
||||
strides. RuntimeError will raise in case of non-compatible strides.
|
||||
:type no_overlap: bool
|
||||
Initialization of the :class:`ContinuousConvBlock` class.
|
||||
|
||||
:param int input_numb_field: The number of input fields.
|
||||
:param int output_numb_field: The number of input fields.
|
||||
:param filter_dim: The shape of the filter.
|
||||
:type filter_dim: list[int] | tuple[int]
|
||||
:param dict stride: The stride of the filter.
|
||||
:param torch.nn.Module model: The neural network for inner
|
||||
parametrization. Default is ``None``.
|
||||
:param bool optimize: If ``True``, optimization is performed on the
|
||||
continuous filter. It should be used only when the training points
|
||||
are fixed. If ``model`` is in ``eval`` mode, it is reset to
|
||||
``False``. Default is ``False``.
|
||||
:param bool no_overlap: If ``True``, optimization is performed on the
|
||||
transposed continuous filter. It should be used only when the filter
|
||||
positions do not overlap for different strides.
|
||||
Default is ``False``.
|
||||
|
||||
.. note::
|
||||
Using `optimize=True` the filter can be use either in `forward`
|
||||
or in `transpose` mode, not both. If `optimize=False` the same
|
||||
filter can be used for both `transpose` and `forward` modes.
|
||||
If ``optimize=True``, the filter can be use either in ``forward``
|
||||
or in ``transpose`` mode, not both.
|
||||
|
||||
:Example:
|
||||
>>> class MLP(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self. model = torch.nn.Sequential(
|
||||
torch.nn.Linear(2, 8),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(8, 8),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(8, 1))
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
... def __init__(self) -> None:
|
||||
... super().__init__()
|
||||
... self. model = torch.nn.Sequential(
|
||||
... torch.nn.Linear(2, 8),
|
||||
... torch.nn.ReLU(),
|
||||
... torch.nn.Linear(8, 8),
|
||||
... torch.nn.ReLU(),
|
||||
... torch.nn.Linear(8, 1)
|
||||
... )
|
||||
... def forward(self, x):
|
||||
... return self.model(x)
|
||||
>>> dim = [3, 3]
|
||||
>>> stride = {"domain": [10, 10],
|
||||
"start": [0, 0],
|
||||
"jumps": [3, 3],
|
||||
"direction": [1, 1.]}
|
||||
>>> stride = {
|
||||
... "domain": [10, 10],
|
||||
... "start": [0, 0],
|
||||
... "jumps": [3, 3],
|
||||
... "direction": [1, 1.]
|
||||
... }
|
||||
>>> conv = ContinuousConv2D(1, 2, dim, stride, MLP)
|
||||
>>> conv
|
||||
ContinuousConv2D(
|
||||
@@ -116,7 +108,6 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
input_numb_field=input_numb_field,
|
||||
output_numb_field=output_numb_field,
|
||||
@@ -143,13 +134,13 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
def _spawn_networks(self, model):
|
||||
"""
|
||||
Private method to create a collection of kernels
|
||||
Create a collection of kernels
|
||||
|
||||
:param model: A :class:`torch.nn.Module` model in form of Object class.
|
||||
:type model: torch.nn.Module
|
||||
:return: List of :class:`torch.nn.Module` models.
|
||||
:param torch.nn.Module model: A neural network model.
|
||||
:raises ValueError: If the model is not a subclass of
|
||||
``torch.nn.Module``.
|
||||
:return: A list of models.
|
||||
:rtype: torch.nn.ModuleList
|
||||
|
||||
"""
|
||||
nets = []
|
||||
if self._net is None:
|
||||
@@ -176,13 +167,11 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
def _extract_mapped_points(self, batch_idx, index, x):
|
||||
"""
|
||||
Priviate method to extract mapped points in the filter
|
||||
Extract mapped points in the filter.
|
||||
|
||||
:param x: Input tensor of shape ``[channel, N, dim]``
|
||||
:type x: torch.Tensor
|
||||
:param torch.Tensor x: Input tensor of shape ``[channel, N, dim]``
|
||||
:return: Mapped points and indeces for each channel,
|
||||
:rtype: torch.Tensor, list
|
||||
|
||||
:rtype: tuple
|
||||
"""
|
||||
mapped_points = []
|
||||
indeces_channels = []
|
||||
@@ -218,11 +207,9 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
def _find_index(self, X):
|
||||
"""
|
||||
Private method to extract indeces for convolution.
|
||||
|
||||
:param X: Input tensor, as in ContinuousConvBlock ``__init__``.
|
||||
:type X: torch.Tensor
|
||||
Extract indeces for convolution.
|
||||
|
||||
:param torch.Tensor X: The input tensor.
|
||||
"""
|
||||
# append the index for each stride
|
||||
index = []
|
||||
@@ -236,11 +223,9 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
def _make_grid_forward(self, X):
|
||||
"""
|
||||
Private method to create forward convolution grid.
|
||||
|
||||
:param X: Input tensor, as in ContinuousConvBlock docstring.
|
||||
:type X: torch.Tensor
|
||||
Create forward convolution grid.
|
||||
|
||||
:param torch.Tensor X: The input tensor.
|
||||
"""
|
||||
# filter dimension + number of points in output grid
|
||||
filter_dim = len(self._dim)
|
||||
@@ -264,12 +249,9 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
def _make_grid_transpose(self, X):
|
||||
"""
|
||||
Private method to create transpose convolution grid.
|
||||
|
||||
:param X: Input tensor, as in ContinuousConvBlock docstring.
|
||||
:type X: torch.Tensor
|
||||
|
||||
Create transpose convolution grid.
|
||||
|
||||
:param torch.Tensor X: The input tensor.
|
||||
"""
|
||||
# initialize to all zeros
|
||||
tmp = torch.zeros_like(X).as_subclass(torch.Tensor)
|
||||
@@ -280,14 +262,12 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
def _make_grid(self, X, type_):
|
||||
"""
|
||||
Private method to create convolution grid.
|
||||
|
||||
:param X: Input tensor, as in ContinuousConvBlock docstring.
|
||||
:type X: torch.Tensor
|
||||
:param type: Type of convolution, ``['forward', 'inverse']`` the
|
||||
possibilities.
|
||||
:type type: str
|
||||
Create convolution grid.
|
||||
|
||||
:param torch.Tensor X: The input tensor.
|
||||
:param str type_: The type of convolution.
|
||||
Available options are: ``forward`` and ``inverse``.
|
||||
:raises TypeError: If the type is not in the available options.
|
||||
"""
|
||||
# choose the type of convolution
|
||||
if type_ == "forward":
|
||||
@@ -300,15 +280,12 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
def _initialize_convolution(self, X, type_="forward"):
|
||||
"""
|
||||
Private method to intialize the convolution.
|
||||
The convolution is initialized by setting a grid and
|
||||
calculate the index for finding the points inside the
|
||||
filter.
|
||||
Initialize the convolution by setting a grid and computing the index to
|
||||
find the points inside the filter.
|
||||
|
||||
:param X: Input tensor, as in ContinuousConvBlock docstring.
|
||||
:type X: torch.Tensor
|
||||
:param str type: type of convolution, ``['forward', 'inverse'] ``the
|
||||
possibilities.
|
||||
:param torch.Tensor X: The input tensor.
|
||||
:param str type_: The type of convolution. Available options are:
|
||||
``forward`` and ``inverse``. Default is ``forward``.
|
||||
"""
|
||||
|
||||
# variable for the convolution
|
||||
@@ -319,11 +296,10 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
def forward(self, X):
|
||||
"""
|
||||
Forward pass in the convolutional layer.
|
||||
Forward pass.
|
||||
|
||||
:param x: Input data for the convolution :math:`[B, N_{in}, N, D]`.
|
||||
:type x: torch.Tensor
|
||||
:return: Convolution output :math:`[B, N_{out}, N, D]`.
|
||||
:param torch.Tensor x: The input tensor.
|
||||
:return: The output tensor.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
@@ -381,25 +357,14 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
def transpose_no_overlap(self, integrals, X):
|
||||
"""
|
||||
Transpose pass in the layer for no-overlapping filters
|
||||
Transpose pass in the layer for no-overlapping filters.
|
||||
|
||||
:param integrals: Weights for the transpose convolution. Shape
|
||||
:math:`[B, N_{in}, N]`
|
||||
where B is the batch_size, :math`N_{in}` is the number of input
|
||||
fields, :math:`N` the number of points in the mesh, D the dimension
|
||||
of the problem.
|
||||
:type integral: torch.tensor
|
||||
:param X: Input data. Expect tensor of shape
|
||||
:math:`[B, N_{in}, M, D]` where :math:`B` is the batch_size,
|
||||
:math`N_{in}`is the number of input fields, :math:`M` the number of
|
||||
points
|
||||
in the mesh, :math:`D` the dimension of the problem.
|
||||
:type X: torch.Tensor
|
||||
:return: Feed forward transpose convolution. Tensor of shape
|
||||
:math:`[B, N_{out}, M, D]` where :math:`B` is the batch_size,
|
||||
:math`N_{out}`is the number of input fields, :math:`M` the number of
|
||||
points
|
||||
in the mesh, :math:`D` the dimension of the problem.
|
||||
:param torch.Tensor integrals: The weights for the transpose convolution.
|
||||
Expected shape :math:`[B, N_{in}, N]`.
|
||||
:param torch.Tensor X: The input data.
|
||||
Expected shape :math:`[B, N_{in}, M, D]`.
|
||||
:return: Feed forward transpose convolution.
|
||||
Expected shape: :math:`[B, N_{out}, M, D]`.
|
||||
:rtype: torch.Tensor
|
||||
|
||||
.. note::
|
||||
@@ -466,25 +431,14 @@ class ContinuousConvBlock(BaseContinuousConv):
|
||||
|
||||
def transpose_overlap(self, integrals, X):
|
||||
"""
|
||||
Transpose pass in the layer for overlapping filters
|
||||
Transpose pass in the layer for overlapping filters.
|
||||
|
||||
:param integrals: Weights for the transpose convolution. Shape
|
||||
:math:`[B, N_{in}, N]`
|
||||
where B is the batch_size, :math`N_{in}` is the number of input
|
||||
fields, :math:`N` the number of points in the mesh, D the dimension
|
||||
of the problem.
|
||||
:type integral: torch.tensor
|
||||
:param X: Input data. Expect tensor of shape
|
||||
:math:`[B, N_{in}, M, D]` where :math:`B` is the batch_size,
|
||||
:math`N_{in}`is the number of input fields, :math:`M` the number of
|
||||
points
|
||||
in the mesh, :math:`D` the dimension of the problem.
|
||||
:type X: torch.Tensor
|
||||
:return: Feed forward transpose convolution. Tensor of shape
|
||||
:math:`[B, N_{out}, M, D]` where :math:`B` is the batch_size,
|
||||
:math`N_{out}`is the number of input fields, :math:`M` the number of
|
||||
points
|
||||
in the mesh, :math:`D` the dimension of the problem.
|
||||
:param torch.Tensor integrals: The weights for the transpose convolution.
|
||||
Expected shape :math:`[B, N_{in}, N]`.
|
||||
:param torch.Tensor X: The input data.
|
||||
Expected shape :math:`[B, N_{in}, M, D]`.
|
||||
:return: Feed forward transpose convolution.
|
||||
Expected shape: :math:`[B, N_{out}, M, D]`.
|
||||
:rtype: torch.Tensor
|
||||
|
||||
.. note:: This function is automatically called when ``.transpose()``
|
||||
|
||||
Reference in New Issue
Block a user