🎨 Format Python code with psf/black
@@ -1,4 +1,5 @@
 """Module for Base Continuous Convolution class."""
+
 from abc import ABCMeta, abstractmethod
 import torch
 from .stride import Stride
@@ -10,14 +11,16 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
     Abstract class
     """
 
-    def __init__(self,
-                 input_numb_field,
-                 output_numb_field,
-                 filter_dim,
-                 stride,
-                 model=None,
-                 optimize=False,
-                 no_overlap=False):
+    def __init__(
+        self,
+        input_numb_field,
+        output_numb_field,
+        filter_dim,
+        stride,
+        model=None,
+        optimize=False,
+        no_overlap=False,
+    ):
         """
         Base Class for Continuous Convolution.
 
@@ -75,43 +78,44 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
         if isinstance(input_numb_field, int):
             self._input_numb_field = input_numb_field
         else:
-            raise ValueError('input_numb_field must be int.')
+            raise ValueError("input_numb_field must be int.")
 
         if isinstance(output_numb_field, int):
             self._output_numb_field = output_numb_field
         else:
-            raise ValueError('input_numb_field must be int.')
+            raise ValueError("input_numb_field must be int.")
 
         if isinstance(filter_dim, (tuple, list)):
             vect = filter_dim
         else:
-            raise ValueError('filter_dim must be tuple or list.')
+            raise ValueError("filter_dim must be tuple or list.")
         vect = torch.tensor(vect)
         self.register_buffer("_dim", vect, persistent=False)
 
         if isinstance(stride, dict):
             self._stride = Stride(stride)
         else:
-            raise ValueError('stride must be dictionary.')
+            raise ValueError("stride must be dictionary.")
 
         self._net = model
 
         if isinstance(optimize, bool):
             self._optimize = optimize
         else:
-            raise ValueError('optimize must be bool.')
+            raise ValueError("optimize must be bool.")
 
         # choosing how to initialize based on optimization
         if self._optimize:
             # optimizing decorator ensure the function is called
             # just once
             self._choose_initialization = optimizing(
-                self._initialize_convolution)
+                self._initialize_convolution
+            )
         else:
             self._choose_initialization = self._initialize_convolution
 
         if not isinstance(no_overlap, bool):
-            raise ValueError('no_overlap must be bool.')
+            raise ValueError("no_overlap must be bool.")
 
         if no_overlap:
             raise NotImplementedError
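
The comment in the hunk above mentions an optimizing decorator that makes sure the convolution initialization is executed only once; its definition is not part of this diff. As a point of reference, a minimal call-once decorator in that spirit could look like the sketch below (the name, placement, and exact behavior are assumptions, not the project's actual implementation):

import functools


def optimizing(init_func):
    """Hypothetical sketch of a call-once decorator, not the project's own code."""

    @functools.wraps(init_func)
    def wrapper(*args, **kwargs):
        # After the first call, skip the wrapped initialization entirely.
        if wrapper._already_called:
            return None
        wrapper._already_called = True
        return init_func(*args, **kwargs)

    wrapper._already_called = False
    return wrapper

Wrapped this way, self._choose_initialization can be invoked repeatedly while the underlying setup runs only the first time.
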
@@ -125,11 +129,13 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
             super().__init__()
             assert isinstance(input_dim, int)
             assert isinstance(output_dim, int)
-            self._model = torch.nn.Sequential(torch.nn.Linear(input_dim, 20),
-                                              torch.nn.ReLU(),
-                                              torch.nn.Linear(20, 20),
-                                              torch.nn.ReLU(),
-                                              torch.nn.Linear(20, output_dim))
+            self._model = torch.nn.Sequential(
+                torch.nn.Linear(input_dim, 20),
+                torch.nn.ReLU(),
+                torch.nn.Linear(20, 20),
+                torch.nn.ReLU(),
+                torch.nn.Linear(20, output_dim),
+            )
 
         def forward(self, x):
             return self._model(x)
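
The final hunk only re-wraps the default kernel network; the layer stack itself is unchanged. For reference, a standalone sketch of the same stack in the black-formatted layout, with hypothetical input and output sizes (not part of the commit):

import torch

# Same feed-forward stack as in the hunk above; the sizes are made up for illustration.
input_dim, output_dim = 2, 1
model = torch.nn.Sequential(
    torch.nn.Linear(input_dim, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, output_dim),
)

x = torch.rand(5, input_dim)  # batch of 5 sample points
print(model(x).shape)  # torch.Size([5, 1])

The resulting module is identical to the one built by the pre-black single-call layout; only the formatting of the call changes.
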