🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -1,4 +1,5 @@
"""Module for Base Continuous Convolution class."""
from abc import ABCMeta, abstractmethod
import torch
from .stride import Stride
@@ -10,14 +11,16 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
Abstract class
"""
def __init__(self,
input_numb_field,
output_numb_field,
filter_dim,
stride,
model=None,
optimize=False,
no_overlap=False):
def __init__(
self,
input_numb_field,
output_numb_field,
filter_dim,
stride,
model=None,
optimize=False,
no_overlap=False,
):
"""
Base Class for Continuous Convolution.
@@ -75,43 +78,44 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
if isinstance(input_numb_field, int):
self._input_numb_field = input_numb_field
else:
raise ValueError('input_numb_field must be int.')
raise ValueError("input_numb_field must be int.")
if isinstance(output_numb_field, int):
self._output_numb_field = output_numb_field
else:
raise ValueError('input_numb_field must be int.')
raise ValueError("input_numb_field must be int.")
if isinstance(filter_dim, (tuple, list)):
vect = filter_dim
else:
raise ValueError('filter_dim must be tuple or list.')
raise ValueError("filter_dim must be tuple or list.")
vect = torch.tensor(vect)
self.register_buffer("_dim", vect, persistent=False)
if isinstance(stride, dict):
self._stride = Stride(stride)
else:
raise ValueError('stride must be dictionary.')
raise ValueError("stride must be dictionary.")
self._net = model
if isinstance(optimize, bool):
self._optimize = optimize
else:
raise ValueError('optimize must be bool.')
raise ValueError("optimize must be bool.")
# choosing how to initialize based on optimization
if self._optimize:
# optimizing decorator ensure the function is called
# just once
self._choose_initialization = optimizing(
self._initialize_convolution)
self._initialize_convolution
)
else:
self._choose_initialization = self._initialize_convolution
if not isinstance(no_overlap, bool):
raise ValueError('no_overlap must be bool.')
raise ValueError("no_overlap must be bool.")
if no_overlap:
raise NotImplementedError
@@ -125,11 +129,13 @@ class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
super().__init__()
assert isinstance(input_dim, int)
assert isinstance(output_dim, int)
self._model = torch.nn.Sequential(torch.nn.Linear(input_dim, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, output_dim))
self._model = torch.nn.Sequential(
torch.nn.Linear(input_dim, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, output_dim),
)
def forward(self, x):
return self._model(x)