Files
PINA/pina/model/layers/stride.py
Dario Coscia 0bcaf62e59 Continuous Convolution (#69)
* network handling update
* adding tutorial
* docs
2023-02-27 10:59:18 +01:00

83 lines
2.7 KiB
Python

import torch
class Stride(object):
    """Discrete stride for continuous convolution.

    Turns a stride dictionary into the tensor of positions at which the
    continuous filter is applied, stored in ``self._stride_discrete``.
    """

    def __init__(self, dict):
        """Build the stride for a continuous convolution.

        :param dict: stride dictionary with keys ``"domain"``, ``"start"``,
            ``"jump"`` and ``"direction"``, each mapping to a sequence with
            one entry per dimension (see ``_create_stride_discrete``).
        :type dict: dict
        """
        # The parameter shadows the builtin ``dict``; the name is kept so
        # keyword callers (``Stride(dict=...)``) keep working.
        self._dict_stride = dict
        # Only initialised to ``None`` here.
        self._stride_continuous = None
        self._stride_discrete = self._create_stride_discrete(dict)

    def _create_stride_discrete(self, my_dict):
        """Create the positions at which the filter is applied.

        :param my_dict: dictionary with the following entries, each a
            sequence with one value per dimension:

            * ``"domain"``: size of the domain (must be greater than 0);
            * ``"start"``: starting position of the filter;
            * ``"jump"``: jump size of the filter (non-negative);
            * ``"direction"``: direction of the filter, one of 1, -1 or 0.

            A zero ``jump`` must be paired with a zero ``direction`` and
            means the filter does not move along that dimension. If the four
            keys above are not all present, the values are read positionally
            in the dictionary's insertion order (domain, start, jump,
            direction) for backward compatibility. The dictionary is never
            modified.
        :type my_dict: dict
        :raises ValueError: the keys are missing and the dictionary does not
            hold exactly four values
        :raises IndexError: values in the dict must all have the same length
        :raises ValueError: domain values must be greater than 0
        :raises ValueError: direction must be either equal to 1, -1 or 0
        :raises ValueError: jump values must be non-negative
        :raises IndexError: direction and jumps must have zero in the same
            index
        :return: filter positions, one row per position and one column per
            dimension
        :rtype: torch.Tensor

        :Example:
            >>> stride = {"domain": [4, 4], "start": [-4, 2],
            ...           "jump": [2, 2], "direction": [1, 1]}
            >>> Stride(stride)._stride_discrete.tolist()
            [[-4.0, 2.0], [-4.0, 4.0], [-2.0, 2.0], [-2.0, 4.0]]
        """
        keys = ("domain", "start", "jump", "direction")
        if all(key in my_dict for key in keys):
            # Read by name so the insertion order of the keys is irrelevant.
            domain, start, jumps, direction = (my_dict[key] for key in keys)
        else:
            # Legacy behaviour: rely on the insertion order of the values.
            domain, start, jumps, direction = my_dict.values()

        # Work on copies: zero jumps are rewritten below, and this must not
        # leak into the caller's dictionary (and must also work for tuples).
        jumps = list(jumps)
        direction = list(direction)

        # checking
        if not all(len(s) == len(domain) for s in (start, jumps, direction)):
            raise IndexError("values in the dict must have all same length")
        if not all(v > 0 for v in domain):
            raise ValueError("domain values must be greater than 0")
        if not all(v in (1, -1, 0) for v in direction):
            raise ValueError("direction must be either equal to 1, -1 or 0")
        if not all(v >= 0 for v in jumps):
            raise ValueError("jump values must be non-negative")

        zero_jumps = [i for i, e in enumerate(jumps) if e == 0]
        zero_directions = [i for i, e in enumerate(direction) if e == 0]
        if zero_directions != zero_jumps:
            raise IndexError(
                "direction and jumps must have zero in the same index")

        # A zero jump means "do not move along this dimension": a step equal
        # to the domain size makes arange yield the single offset 0.
        for i in zero_jumps:
            jumps[i] = domain[i]
            direction[i] = 1

        # creating the stride grid: per-dimension offsets, signed by direction
        values_mesh = [torch.arange(0, size, step).float() * sign
                       for size, step, sign in zip(domain, jumps, direction)]
        # "ij" is torch's historical default; passing it explicitly keeps the
        # same ordering without the missing-``indexing`` warning.
        mesh = torch.meshgrid(*values_mesh, indexing="ij")
        coordinates_mesh = [x.reshape(-1, 1) for x in mesh]
        return torch.cat(coordinates_mesh, dim=1) + torch.tensor(start)