import torch import torch.nn as nn from ...utils import check_consistency class ResidualBlock(nn.Module): """Residual block base class. Implementation of a residual block. .. seealso:: **Original reference**: He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.. `_. """ def __init__(self, input_dim, output_dim, hidden_dim, spectral_norm=False, activation=torch.nn.ReLU()): """Residual block constructor :param int input_dim: Dimension of the input to pass to the feedforward linear layer. :param int output_dim: Dimension of the output from the residual layer. :param int hidden_dim: Hidden dimension for mapping the input (first block). :param bool spectral_norm: Apply spectral normalization to feedforward layers, defaults to False. :param torch.nn.Module activation: Cctivation function after first block. """ super().__init__() # check consistency check_consistency(spectral_norm, bool) check_consistency(input_dim, int) check_consistency(output_dim, int) check_consistency(hidden_dim, int) check_consistency(activation, torch.nn.Module) # assign variables self._spectral_norm = spectral_norm self._input_dim = input_dim self._output_dim = output_dim self._hidden_dim = hidden_dim self._activation = activation # create layers self.l1 = self._spect_norm(nn.Linear(input_dim, hidden_dim)) self.l2 = self._spect_norm(nn.Linear(hidden_dim, output_dim)) self.l3 = self._spect_norm(nn.Linear(input_dim, output_dim)) def forward(self, x): """Forward pass for residual block layer. :param torch.Tensor x: Input tensor for the residual layer. :return: Output tensor for the residual layer. :rtype: torch.Tensor """ y = self.activation(self.l1(x)) y = self.l2(y) x = self.l3(x) return y + x def _spect_norm(self, x): """Perform spectral norm on the layers. :param x: A torch.nn.Module Linear layer :type x: torch.nn.Module :return: The spectral norm of the layer :rtype: torch.nn.Module """ return nn.utils.spectral_norm(x) if self._spectral_norm else x @ property def spectral_norm(self): return self._spectral_norm @ property def input_dim(self): return self._input_dim @ property def output_dim(self): return self._output_dim @ property def hidden_dim(self): return self._hidden_dim @ property def activation(self): return self._activation class EnhancedLinear(torch.nn.Module): """ TODO """ def __init__(self, layer, activation=None, dropout=None): super().__init__() # check consistency check_consistency(layer, nn.Module) if activation is not None: check_consistency(activation, nn.Module) if dropout is not None: check_consistency(dropout, float) # assign forward if (dropout is None) and (activation is None): self._model = torch.nn.Sequential(layer) elif (dropout is None) and (activation is not None): self._model = torch.nn.Sequential(layer, activation) elif (dropout is not None) and (activation is None): self._model = torch.nn.Sequential(layer, self._drop(dropout)) elif (dropout is not None) and (activation is not None): self._model = torch.nn.Sequential(layer, activation, self._drop(dropout)) def forward(self, x): return self._model(x)