New Residual Model and Fix relative import

* Add Residual MLP
* Add tests for Residual MLP
* Fix relative import in Continuous Conv
Dario Coscia
2023-09-13 12:45:22 +02:00
committed by Nicola Demo
parent ba7371f350
commit 17464ceca9
7 changed files with 211 additions and 9 deletions


@@ -92,4 +92,39 @@ class ResidualBlock(nn.Module):
    @property
    def activation(self):
        return self._activation
class EnhancedLinear(torch.nn.Module):
    """
    A wrapper that combines a layer with an optional activation
    function and an optional dropout.
    """

    def __init__(self, layer, activation=None, dropout=None):
        super().__init__()

        # check consistency
        check_consistency(layer, nn.Module)
        if activation is not None:
            check_consistency(activation, nn.Module)
        if dropout is not None:
            check_consistency(dropout, float)

        # assign forward
        if (dropout is None) and (activation is None):
            self._model = torch.nn.Sequential(layer)
        elif (dropout is None) and (activation is not None):
            self._model = torch.nn.Sequential(layer,
                                              activation)
        elif (dropout is not None) and (activation is None):
            self._model = torch.nn.Sequential(layer,
                                              self._drop(dropout))
        elif (dropout is not None) and (activation is not None):
            self._model = torch.nn.Sequential(layer,
                                              activation,
                                              self._drop(dropout))

    def forward(self, x):
        return self._model(x)

    def _drop(self, p):
        # dropout factory referenced in the constructor above
        return torch.nn.Dropout(p)
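
For context, a minimal usage sketch of the new EnhancedLinear wrapper, assuming the constructor signature shown in the hunk above; the layer sizes, activation, and dropout rate below are illustrative placeholders, not values from the commit:

import torch

# wrap a linear layer with a ReLU activation and dropout (p=0.1)
enhanced = EnhancedLinear(layer=torch.nn.Linear(10, 20),
                          activation=torch.nn.ReLU(),
                          dropout=0.1)

x = torch.rand(8, 10)   # batch of 8 samples, 10 features each
y = enhanced(x)         # output shape: (8, 20)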