fix doc model part 2
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Module for Base Continuous Convolution class."""
|
||||
"""Module for the Base Continuous Convolution class."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import torch
|
||||
@@ -7,8 +7,31 @@ from .utils_convolution import optimizing
|
||||
|
||||
|
||||
class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
|
||||
"""
|
||||
Abstract class
|
||||
r"""
|
||||
Base Class for Continuous Convolution.
|
||||
|
||||
The class expects the input to be in the form:
|
||||
:math:`[B \times N_{in} \times N \times D]`, where :math:`B` is the
|
||||
batch_size, :math:`N_{in}` is the number of input fields, :math:`N`
|
||||
the number of points in the mesh, :math:`D` the dimension of the problem.
|
||||
In particular:
|
||||
* :math:`D` is the number of spatial variables + 1. The last column must
|
||||
contain the field value.
|
||||
* :math:`N_{in}` represents the number of function components.
|
||||
For instance, a vectorial function :math:`f = [f_1, f_2]` has
|
||||
:math:`N_{in}=2`.
|
||||
|
||||
:Note
|
||||
A 2-dimensional vector-valued function defined on a 3-dimensional input
|
||||
evaluated on a 100 points input mesh and batch size of 8 is represented
|
||||
as a tensor of shape ``[8, 2, 100, 4]``, where the columns
|
||||
``[:, 0, :, -1]`` and ``[:, 1, :, -1]`` represent the first and second,
|
||||
components of the function, respectively.
|
||||
|
||||
The algorithm returns a tensor of shape:
|
||||
:math:`[B \times N_{out} \times N \times D]`, where :math:`B` is the
|
||||
batch_size, :math:`N_{out}` is the number of output fields, :math:`N`
|
||||
the number of points in the mesh, :math:`D` the dimension of the problem.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -22,56 +45,30 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
|
||||
no_overlap=False,
|
||||
):
|
||||
"""
|
||||
Base Class for Continuous Convolution.
|
||||
Initialization of the :class:`BaseContinuousConv` class.
|
||||
|
||||
The algorithm expects input to be in the form:
|
||||
$$[B \times N_{in} \times N \times D]$$
|
||||
where $B$ is the batch_size, $N_{in}$ is the number of input
|
||||
fields, $N$ the number of points in the mesh, $D$ the dimension
|
||||
of the problem. In particular:
|
||||
* $D$ is the number of spatial variables + 1. The last column must
|
||||
contain the field value. For example for 2D problems $D=3$ and
|
||||
the tensor will be something like `[first coordinate, second
|
||||
coordinate, field value]`.
|
||||
* $N_{in}$ represents the number of vectorial function presented.
|
||||
For example a vectorial function $f = [f_1, f_2]$ will have
|
||||
$N_{in}=2$.
|
||||
|
||||
:Note
|
||||
A 2-dimensional vectorial function $N_{in}=2$ of 3-dimensional
|
||||
input $D=3+1=4$ with 100 points input mesh and batch size of 8
|
||||
is represented as a tensor `[8, 2, 100, 4]`, where the columns
|
||||
`[:, 0, :, -1]` and `[:, 1, :, -1]` represent the first and
|
||||
second filed value respectively
|
||||
|
||||
The algorithm returns a tensor of shape:
|
||||
$$[B \times N_{out} \times N' \times D]$$
|
||||
where $B$ is the batch_size, $N_{out}$ is the number of output
|
||||
fields, $N'$ the number of points in the mesh, $D$ the dimension
|
||||
of the problem.
|
||||
|
||||
:param input_numb_field: number of fields in the input
|
||||
:type input_numb_field: int
|
||||
:param output_numb_field: number of fields in the output
|
||||
:type output_numb_field: int
|
||||
:param filter_dim: dimension of the filter
|
||||
:type filter_dim: tuple/ list
|
||||
:param stride: stride for the filter
|
||||
:type stride: dict
|
||||
:param model: neural network for inner parametrization,
|
||||
defaults to None.
|
||||
:type model: torch.nn.Module, optional
|
||||
:param optimize: flag for performing optimization on the continuous
|
||||
filter, defaults to False. The flag `optimize=True` should be
|
||||
used only when the scatter datapoints are fixed through the
|
||||
training. If torch model is in `.eval()` mode, the flag is
|
||||
automatically set to False always.
|
||||
:type optimize: bool, optional
|
||||
:param no_overlap: flag for performing optimization on the transpose
|
||||
continuous filter, defaults to False. The flag set to `True` should
|
||||
be used only when the filter positions do not overlap for different
|
||||
strides. RuntimeError will raise in case of non-compatible strides.
|
||||
:type no_overlap: bool, optional
|
||||
:param int input_numb_field: The number of input fields.
|
||||
:param int output_numb_field: The number of input fields.
|
||||
:param filter_dim: The shape of the filter.
|
||||
:type filter_dim: list[int] | tuple[int]
|
||||
:param dict stride: The stride of the filter.
|
||||
:param torch.nn.Module model: The neural network for inner
|
||||
parametrization. Default is ``None``.
|
||||
:param bool optimize: If ``True``, optimization is performed on the
|
||||
continuous filter. It should be used only when the training points
|
||||
are fixed. If ``model`` is in ``eval`` mode, it is reset to
|
||||
``False``. Default is ``False``.
|
||||
:param bool no_overlap: If ``True``, optimization is performed on the
|
||||
transposed continuous filter. It should be used only when the filter
|
||||
positions do not overlap for different strides.
|
||||
Default is ``False``.
|
||||
:raises ValueError: If ``input_numb_field`` is not an integer.
|
||||
:raises ValueError: If ``output_numb_field`` is not an integer.
|
||||
:raises ValueError: If ``filter_dim`` is not a list or tuple.
|
||||
:raises ValueError: If ``stride`` is not a dictionary.
|
||||
:raises ValueError: If ``optimize`` is not a boolean.
|
||||
:raises ValueError: If ``no_overlap`` is not a boolean.
|
||||
:raises NotImplementedError: If ``no_overlap`` is ``True``.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@@ -119,12 +116,17 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
|
||||
|
||||
class DefaultKernel(torch.nn.Module):
|
||||
"""
|
||||
TODO
|
||||
The default kernel.
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
"""
|
||||
TODO
|
||||
Initialization of the :class:`DefaultKernel` class.
|
||||
|
||||
:param int input_dim: The input dimension.
|
||||
:param int output_dim: The output dimension.
|
||||
:raises ValueError: If ``input_dim`` is not an integer.
|
||||
:raises ValueError: If ``output_dim`` is not an integer.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(input_dim, int)
|
||||
@@ -139,65 +141,93 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
TODO
|
||||
Forward pass.
|
||||
|
||||
:param torch.Tensor x: The input data.
|
||||
:return: The output data.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._model(x)
|
||||
|
||||
@property
|
||||
def net(self):
|
||||
"""
|
||||
TODO
|
||||
The neural network for inner parametrization.
|
||||
|
||||
:return: The neural network.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self._net
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
"""
|
||||
TODO
|
||||
The stride of the filter.
|
||||
|
||||
:return: The stride of the filter.
|
||||
:rtype: dict
|
||||
"""
|
||||
return self._stride
|
||||
|
||||
@property
|
||||
def filter_dim(self):
|
||||
"""
|
||||
TODO
|
||||
The shape of the filter.
|
||||
|
||||
:return: The shape of the filter.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return self._dim
|
||||
|
||||
@property
|
||||
def input_numb_field(self):
|
||||
"""
|
||||
TODO
|
||||
The number of input fields.
|
||||
|
||||
:return: The number of input fields.
|
||||
:rtype: int
|
||||
"""
|
||||
return self._input_numb_field
|
||||
|
||||
@property
|
||||
def output_numb_field(self):
|
||||
"""
|
||||
TODO
|
||||
The number of output fields.
|
||||
|
||||
:return: The number of output fields.
|
||||
:rtype: int
|
||||
"""
|
||||
return self._output_numb_field
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, X):
|
||||
"""
|
||||
TODO
|
||||
Forward pass.
|
||||
|
||||
:param torch.Tensor X: The input data.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def transpose_overlap(self, X):
|
||||
"""
|
||||
TODO
|
||||
Transpose the convolution with overlap.
|
||||
|
||||
:param torch.Tensor X: The input data.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def transpose_no_overlap(self, X):
|
||||
"""
|
||||
TODO
|
||||
Transpose the convolution without overlap.
|
||||
|
||||
:param torch.Tensor X: The input data.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_convolution(self, X, type_):
|
||||
"""
|
||||
TODO
|
||||
Initialize the convolution.
|
||||
|
||||
:param torch.Tensor X: The input data.
|
||||
:param str type_: The type of initialization.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user