Continuous Convolution (#69)
* network handling update * adding tutorial * docs
This commit is contained in:
48
pina/model/layers/utils_convolution.py
Normal file
48
pina/model/layers/utils_convolution.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
|
||||
|
||||
def check_point(x, current_stride, dim):
|
||||
max_stride = current_stride + dim
|
||||
indeces = torch.logical_and(x[..., :-1] < max_stride,
|
||||
x[..., :-1] >= current_stride).all(dim=-1)
|
||||
return indeces
|
||||
|
||||
|
||||
def map_points_(x, filter_position):
|
||||
"""Mapping function n dimensional case
|
||||
|
||||
:param x: input data of two dimension
|
||||
:type x: torch.tensor
|
||||
:param filter_position: position of the filter
|
||||
:type dim: list[numeric]
|
||||
:return: data mapped inplace
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
x.add_(-filter_position)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def optimizing(f):
|
||||
"""Decorator for calling a function just once
|
||||
|
||||
:param f: python function
|
||||
:type f: function
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
if kwargs['type'] == 'forward':
|
||||
if not wrapper.has_run_inverse:
|
||||
wrapper.has_run_inverse = True
|
||||
return f(*args, **kwargs)
|
||||
|
||||
if kwargs['type'] == 'inverse':
|
||||
if not wrapper.has_run:
|
||||
wrapper.has_run = True
|
||||
return f(*args, **kwargs)
|
||||
|
||||
wrapper.has_run_inverse = False
|
||||
wrapper.has_run = False
|
||||
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user