Documentation for v0.1 version (#199)

* Adding Equations, solving typos
* improve _code.rst
* the team rst and restuctore index.rst
* fixing errors

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-11-08 14:39:00 +01:00
committed by Nicola Demo
parent 3f9305d475
commit 8b7b61b3bd
144 changed files with 2741 additions and 1766 deletions

View File

@@ -10,10 +10,16 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
Abstract class
"""
def __init__(self, input_numb_field, output_numb_field,
filter_dim, stride, model=None, optimize=False,
def __init__(self,
input_numb_field,
output_numb_field,
filter_dim,
stride,
model=None,
optimize=False,
no_overlap=False):
"""Base Class for Continuous Convolution.
"""
Base Class for Continuous Convolution.
The algorithm expects input to be in the form:
$$[B \times N_{in} \times N \times D]$$
@@ -50,7 +56,7 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
:param stride: stride for the filter
:type stride: dict
:param model: neural network for inner parametrization,
defaults to None
defaults to None.
:type model: torch.nn.Module, optional
:param optimize: flag for performing optimization on the continuous
filter, defaults to False. The flag `optimize=True` should be
@@ -114,37 +120,37 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
self.transpose = self.transpose_overlap
class DefaultKernel(torch.nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
assert isinstance(input_dim, int)
assert isinstance(output_dim, int)
self._model = torch.nn.Sequential(
torch.nn.Linear(input_dim, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, output_dim)
)
self._model = torch.nn.Sequential(torch.nn.Linear(input_dim, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, output_dim))
def forward(self, x):
return self._model(x)
@ property
@property
def net(self):
return self._net
@ property
@property
def stride(self):
return self._stride
@ property
@property
def filter_dim(self):
return self._dim
@ property
@property
def input_numb_field(self):
return self._input_numb_field
@ property
@property
def output_numb_field(self):
return self._output_numb_field