fix doc model part 2
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
TODO: Add description
|
||||
Module for the Stride class.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -7,14 +7,16 @@ import torch
|
||||
|
||||
class Stride:
|
||||
"""
|
||||
TODO
|
||||
Stride class for continous convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, dict_):
|
||||
"""Stride class for continous convolution
|
||||
"""
|
||||
Initialization of the :class:`Stride` class.
|
||||
|
||||
:param param: type of continuous convolution
|
||||
:type param: string
|
||||
:param dict dict_: Dictionary having as keys the domain size ``domain``,
|
||||
the starting position of the filter ``start``, the jump size for the
|
||||
filter ``jump``, and the direction of the filter ``direction``.
|
||||
"""
|
||||
|
||||
self._dict_stride = dict_
|
||||
@@ -22,52 +24,50 @@ class Stride:
|
||||
self._stride_discrete = self._create_stride_discrete(dict_)
|
||||
|
||||
def _create_stride_discrete(self, my_dict):
|
||||
"""Creating the list for applying the filter
|
||||
"""
|
||||
Create a tensor of positions where to apply the filter.
|
||||
|
||||
:param dict my_dict_: Dictionary having as keys the domain size
|
||||
``domain``, the starting position of the filter ``start``, the jump
|
||||
size for the filter ``jump``, and the direction of the filter
|
||||
``direction``.
|
||||
:raises IndexError: Values in the dict must have all same length.
|
||||
:raises ValueError: Domain values must be greater than 0.
|
||||
:raises ValueError: Direction must be either equal to ``1``, ``-1`` or
|
||||
``0``.
|
||||
:raises IndexError: Direction and jumps must be zero in the same index.
|
||||
:return: The positions for the filter
|
||||
:rtype: torch.Tensor
|
||||
|
||||
:param my_dict: Dictionary with the following arguments:
|
||||
domain size, starting position of the filter, jump size
|
||||
for the filter and direction of the filter
|
||||
:type my_dict: dict
|
||||
:raises IndexError: Values in the dict must have all same length
|
||||
:raises ValueError: Domain values must be greater than 0
|
||||
:raises ValueError: Direction must be either equal to 1, -1 or 0
|
||||
:raises IndexError: Direction and jumps must have zero in the same
|
||||
index
|
||||
:return: list of positions for the filter
|
||||
:rtype: list
|
||||
:Example:
|
||||
|
||||
|
||||
>>> stride = {"domain": [4, 4],
|
||||
"start": [-4, 2],
|
||||
"jump": [2, 2],
|
||||
"direction": [1, 1],
|
||||
}
|
||||
>>> create_stride(stride)
|
||||
[[-4.0, 2.0], [-4.0, 4.0], [-2.0, 2.0], [-2.0, 4.0]]
|
||||
>>> stride_dict = {
|
||||
... "domain": [4, 4],
|
||||
... "start": [-4, 2],
|
||||
... "jump": [2, 2],
|
||||
... "direction": [1, 1],
|
||||
... }
|
||||
>>> Stride(stride_dict)
|
||||
"""
|
||||
|
||||
# we must check boundaries of the input as well
|
||||
|
||||
domain, start, jumps, direction = my_dict.values()
|
||||
|
||||
# checking
|
||||
|
||||
if not all(len(s) == len(domain) for s in my_dict.values()):
|
||||
raise IndexError("values in the dict must have all same length")
|
||||
raise IndexError("Values in the dict must have all same length")
|
||||
|
||||
if not all(v >= 0 for v in domain):
|
||||
raise ValueError("domain values must be greater than 0")
|
||||
raise ValueError("Domain values must be greater than 0")
|
||||
|
||||
if not all(v in (0, -1, 1) for v in direction):
|
||||
raise ValueError("direction must be either equal to 1, -1 or 0")
|
||||
raise ValueError("Direction must be either equal to 1, -1 or 0")
|
||||
|
||||
seq_jumps = [i for i, e in enumerate(jumps) if e == 0]
|
||||
seq_direction = [i for i, e in enumerate(direction) if e == 0]
|
||||
|
||||
if seq_direction != seq_jumps:
|
||||
raise IndexError(
|
||||
"direction and jumps must have zero in the same index"
|
||||
"Direction and jumps must have zero in the same index"
|
||||
)
|
||||
|
||||
if seq_jumps:
|
||||
|
||||
Reference in New Issue
Block a user