fix doc model part 2

This commit is contained in:
giovanni
2025-03-14 16:07:08 +01:00
committed by FilippoOlivo
parent 194f5d24c4
commit 28d24f3f41
18 changed files with 887 additions and 851 deletions

View File

@@ -1,4 +1,4 @@
"""Module for Base Continuous Convolution class."""
"""Module for the Base Continuous Convolution class."""
from abc import ABCMeta, abstractmethod
import torch
@@ -7,8 +7,31 @@ from .utils_convolution import optimizing
class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
"""
Abstract class
r"""
Base Class for Continuous Convolution.
The class expects the input to be in the form:
:math:`[B \times N_{in} \times N \times D]`, where :math:`B` is the
batch_size, :math:`N_{in}` is the number of input fields, :math:`N`
the number of points in the mesh, :math:`D` the dimension of the problem.
In particular:
* :math:`D` is the number of spatial variables + 1. The last column must
contain the field value.
* :math:`N_{in}` represents the number of function components.
For instance, a vectorial function :math:`f = [f_1, f_2]` has
:math:`N_{in}=2`.
:Note
A 2-dimensional vector-valued function defined on a 3-dimensional input
evaluated on a 100 points input mesh and batch size of 8 is represented
as a tensor of shape ``[8, 2, 100, 4]``, where the columns
``[:, 0, :, -1]`` and ``[:, 1, :, -1]`` represent the first and second,
components of the function, respectively.
The algorithm returns a tensor of shape:
:math:`[B \times N_{out} \times N \times D]`, where :math:`B` is the
batch_size, :math:`N_{out}` is the number of output fields, :math:`N`
the number of points in the mesh, :math:`D` the dimension of the problem.
"""
def __init__(
@@ -22,56 +45,30 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
no_overlap=False,
):
"""
Base Class for Continuous Convolution.
Initialization of the :class:`BaseContinuousConv` class.
The algorithm expects input to be in the form:
$$[B \times N_{in} \times N \times D]$$
where $B$ is the batch_size, $N_{in}$ is the number of input
fields, $N$ the number of points in the mesh, $D$ the dimension
of the problem. In particular:
* $D$ is the number of spatial variables + 1. The last column must
contain the field value. For example for 2D problems $D=3$ and
the tensor will be something like `[first coordinate, second
coordinate, field value]`.
* $N_{in}$ represents the number of vectorial function presented.
For example a vectorial function $f = [f_1, f_2]$ will have
$N_{in}=2$.
:Note
A 2-dimensional vectorial function $N_{in}=2$ of 3-dimensional
input $D=3+1=4$ with 100 points input mesh and batch size of 8
is represented as a tensor `[8, 2, 100, 4]`, where the columns
`[:, 0, :, -1]` and `[:, 1, :, -1]` represent the first and
second filed value respectively
The algorithm returns a tensor of shape:
$$[B \times N_{out} \times N' \times D]$$
where $B$ is the batch_size, $N_{out}$ is the number of output
fields, $N'$ the number of points in the mesh, $D$ the dimension
of the problem.
:param input_numb_field: number of fields in the input
:type input_numb_field: int
:param output_numb_field: number of fields in the output
:type output_numb_field: int
:param filter_dim: dimension of the filter
:type filter_dim: tuple/ list
:param stride: stride for the filter
:type stride: dict
:param model: neural network for inner parametrization,
defaults to None.
:type model: torch.nn.Module, optional
:param optimize: flag for performing optimization on the continuous
filter, defaults to False. The flag `optimize=True` should be
used only when the scatter datapoints are fixed through the
training. If torch model is in `.eval()` mode, the flag is
automatically set to False always.
:type optimize: bool, optional
:param no_overlap: flag for performing optimization on the transpose
continuous filter, defaults to False. The flag set to `True` should
be used only when the filter positions do not overlap for different
strides. RuntimeError will raise in case of non-compatible strides.
:type no_overlap: bool, optional
:param int input_numb_field: The number of input fields.
:param int output_numb_field: The number of input fields.
:param filter_dim: The shape of the filter.
:type filter_dim: list[int] | tuple[int]
:param dict stride: The stride of the filter.
:param torch.nn.Module model: The neural network for inner
parametrization. Default is ``None``.
:param bool optimize: If ``True``, optimization is performed on the
continuous filter. It should be used only when the training points
are fixed. If ``model`` is in ``eval`` mode, it is reset to
``False``. Default is ``False``.
:param bool no_overlap: If ``True``, optimization is performed on the
transposed continuous filter. It should be used only when the filter
positions do not overlap for different strides.
Default is ``False``.
:raises ValueError: If ``input_numb_field`` is not an integer.
:raises ValueError: If ``output_numb_field`` is not an integer.
:raises ValueError: If ``filter_dim`` is not a list or tuple.
:raises ValueError: If ``stride`` is not a dictionary.
:raises ValueError: If ``optimize`` is not a boolean.
:raises ValueError: If ``no_overlap`` is not a boolean.
:raises NotImplementedError: If ``no_overlap`` is ``True``.
"""
super().__init__()
@@ -119,12 +116,17 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
class DefaultKernel(torch.nn.Module):
"""
TODO
The default kernel.
"""
def __init__(self, input_dim, output_dim):
"""
TODO
Initialization of the :class:`DefaultKernel` class.
:param int input_dim: The input dimension.
:param int output_dim: The output dimension.
:raises ValueError: If ``input_dim`` is not an integer.
:raises ValueError: If ``output_dim`` is not an integer.
"""
super().__init__()
assert isinstance(input_dim, int)
@@ -139,65 +141,93 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
def forward(self, x):
"""
TODO
Forward pass.
:param torch.Tensor x: The input data.
:return: The output data.
:rtype: torch.Tensor
"""
return self._model(x)
@property
def net(self):
"""
TODO
The neural network for inner parametrization.
:return: The neural network.
:rtype: torch.nn.Module
"""
return self._net
@property
def stride(self):
"""
TODO
The stride of the filter.
:return: The stride of the filter.
:rtype: dict
"""
return self._stride
@property
def filter_dim(self):
"""
TODO
The shape of the filter.
:return: The shape of the filter.
:rtype: torch.Tensor
"""
return self._dim
@property
def input_numb_field(self):
"""
TODO
The number of input fields.
:return: The number of input fields.
:rtype: int
"""
return self._input_numb_field
@property
def output_numb_field(self):
"""
TODO
The number of output fields.
:return: The number of output fields.
:rtype: int
"""
return self._output_numb_field
@abstractmethod
def forward(self, X):
"""
TODO
Forward pass.
:param torch.Tensor X: The input data.
"""
@abstractmethod
def transpose_overlap(self, X):
"""
TODO
Transpose the convolution with overlap.
:param torch.Tensor X: The input data.
"""
@abstractmethod
def transpose_no_overlap(self, X):
"""
TODO
Transpose the convolution without overlap.
:param torch.Tensor X: The input data.
"""
@abstractmethod
def _initialize_convolution(self, X, type_):
"""
TODO
Initialize the convolution.
:param torch.Tensor X: The input data.
:param str type_: The type of initialization.
"""