70 lines
1.7 KiB
Python
70 lines
1.7 KiB
Python
"""
|
|
Module for utility functions for the convolutional layer.
|
|
"""
|
|
|
|
import torch
|
|
|
|
|
|
def check_point(x, current_stride, dim):
|
|
"""
|
|
Check if the point is in the current stride.
|
|
|
|
:param torch.Tensor x: The input data.
|
|
:param int current_stride: The current stride.
|
|
:param int dim: The shape of the filter.
|
|
:return: The indeces of the points in the current stride.
|
|
:rtype: torch.Tensor
|
|
"""
|
|
max_stride = current_stride + dim
|
|
indeces = torch.logical_and(
|
|
x[..., :-1] < max_stride, x[..., :-1] >= current_stride
|
|
).all(dim=-1)
|
|
return indeces
|
|
|
|
|
|
def map_points_(x, filter_position):
|
|
"""
|
|
The mapping function for n-dimensional case.
|
|
|
|
:param torch.Tensor x: The two-dimensional input data.
|
|
:param list[int] filter_position: The position of the filter.
|
|
:return: The data mapped in-place.
|
|
:rtype: torch.tensor
|
|
"""
|
|
x.add_(-filter_position)
|
|
|
|
return x
|
|
|
|
|
|
def optimizing(f):
|
|
"""
|
|
Decorator to call the function only once.
|
|
|
|
:param f: python function
|
|
:type f: Callable
|
|
"""
|
|
|
|
def wrapper(*args, **kwargs):
|
|
"""
|
|
Wrapper function.
|
|
|
|
:param args: The arguments of the function.
|
|
:param kwargs: The keyword arguments of the function.
|
|
"""
|
|
if kwargs["type_"] == "forward":
|
|
if not wrapper.has_run_inverse:
|
|
wrapper.has_run_inverse = True
|
|
return f(*args, **kwargs)
|
|
|
|
if kwargs["type_"] == "inverse":
|
|
if not wrapper.has_run:
|
|
wrapper.has_run = True
|
|
return f(*args, **kwargs)
|
|
|
|
return f(*args, **kwargs)
|
|
|
|
wrapper.has_run_inverse = False
|
|
wrapper.has_run = False
|
|
|
|
return wrapper
|