PINA/pina/model/layers/residual.py

import torch
import torch.nn as nn
from ...utils import check_consistency

class ResidualBlock(nn.Module):
    """Residual block base class. Implementation of a residual block.

    .. seealso::

        **Original reference**: He, Kaiming, et al.
        "Deep residual learning for image recognition."
        Proceedings of the IEEE conference on computer vision
        and pattern recognition. 2016.
        `<https://arxiv.org/pdf/1512.03385.pdf>`_
    """
    def __init__(self, input_dim, output_dim, hidden_dim,
                 spectral_norm=False, activation=torch.nn.ReLU()):
        """Residual block constructor.

        :param int input_dim: Dimension of the input to pass to the
            feedforward linear layer.
        :param int output_dim: Dimension of the output from the
            residual layer.
        :param int hidden_dim: Hidden dimension for mapping the input
            (first block).
        :param bool spectral_norm: Apply spectral normalization to the
            feedforward layers, defaults to False.
        :param torch.nn.Module activation: Activation function applied
            after the first block.
        """
        super().__init__()

        # check consistency
        check_consistency(spectral_norm, bool)
        check_consistency(input_dim, int)
        check_consistency(output_dim, int)
        check_consistency(hidden_dim, int)
        check_consistency(activation, torch.nn.Module)

        # assign variables
        self._spectral_norm = spectral_norm
        self._input_dim = input_dim
        self._output_dim = output_dim
        self._hidden_dim = hidden_dim
        self._activation = activation

        # create layers: l1 and l2 form the residual branch, while l3 is
        # the linear skip connection mapping input_dim -> output_dim
        self.l1 = self._spect_norm(nn.Linear(input_dim, hidden_dim))
        self.l2 = self._spect_norm(nn.Linear(hidden_dim, output_dim))
        self.l3 = self._spect_norm(nn.Linear(input_dim, output_dim))
    def forward(self, x):
        """Forward pass for the residual block layer.

        :param torch.Tensor x: Input tensor for the residual layer.
        :return: Output tensor for the residual layer.
        :rtype: torch.Tensor
        """
        # residual branch
        y = self.activation(self.l1(x))
        y = self.l2(y)
        # linear skip connection
        x = self.l3(x)
        return y + x

    def _spect_norm(self, x):
        """Apply spectral normalization to a layer, if enabled.

        :param torch.nn.Module x: A torch.nn.Linear layer.
        :return: The layer wrapped with spectral normalization if
            ``spectral_norm=True``, otherwise the layer unchanged.
        :rtype: torch.nn.Module
        """
        return nn.utils.spectral_norm(x) if self._spectral_norm else x
    @property
    def spectral_norm(self):
        return self._spectral_norm

    @property
    def input_dim(self):
        return self._input_dim

    @property
    def output_dim(self):
        return self._output_dim

    @property
    def hidden_dim(self):
        return self._hidden_dim

    @property
    def activation(self):
        return self._activation
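
# Illustrative usage of ResidualBlock (a minimal sketch; the shapes below
# are arbitrary and not part of the original module):
#
#     block = ResidualBlock(input_dim=3, output_dim=5, hidden_dim=16)
#     x = torch.rand(10, 3)
#     y = block(x)  # shape (10, 5): l2(activation(l1(x))) + l3(x)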

class EnhancedLinear(torch.nn.Module):
    """A wrapper that enhances a layer with an optional activation
    function and an optional dropout, applied in that order.
    """

    def __init__(self, layer, activation=None, dropout=None):
        """Enhanced linear constructor.

        :param torch.nn.Module layer: The layer to wrap.
        :param torch.nn.Module activation: Activation function applied
            after the layer, defaults to None.
        :param float dropout: Dropout probability applied after the
            activation, defaults to None.
        """
        super().__init__()

        # check consistency
        check_consistency(layer, nn.Module)
        if activation is not None:
            check_consistency(activation, nn.Module)
        if dropout is not None:
            check_consistency(dropout, float)

        # assemble the wrapped model depending on which extras are given
        if (dropout is None) and (activation is None):
            self._model = torch.nn.Sequential(layer)
        elif (dropout is None) and (activation is not None):
            self._model = torch.nn.Sequential(layer, activation)
        elif (dropout is not None) and (activation is None):
            self._model = torch.nn.Sequential(layer, self._drop(dropout))
        elif (dropout is not None) and (activation is not None):
            self._model = torch.nn.Sequential(layer, activation,
                                              self._drop(dropout))

    def forward(self, x):
        """Forward pass for the enhanced linear layer.

        :param torch.Tensor x: Input tensor.
        :return: Output tensor.
        :rtype: torch.Tensor
        """
        return self._model(x)

    def _drop(self, p):
        """Dropout layer creation.

        :param float p: Dropout probability.
        :return: The dropout layer.
        :rtype: torch.nn.Dropout
        """
        return torch.nn.Dropout(p)
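

# Minimal smoke test for EnhancedLinear (an illustrative sketch, not part
# of the original module; the layer sizes and dropout rate are arbitrary):
if __name__ == "__main__":
    enhanced = EnhancedLinear(nn.Linear(3, 5),
                              activation=nn.Tanh(),
                              dropout=0.1)
    out = enhanced(torch.rand(4, 3))
    print(out.shape)  # torch.Size([4, 5])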