🎨 Format Python code with psf/black
@@ -16,18 +16,20 @@ class ResidualBlock(nn.Module):
     """
 
-    def __init__(self,
-                 input_dim,
-                 output_dim,
-                 hidden_dim,
-                 spectral_norm=False,
-                 activation=torch.nn.ReLU()):
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        hidden_dim,
+        spectral_norm=False,
+        activation=torch.nn.ReLU(),
+    ):
         """
         Initializes the ResidualBlock module.
 
         :param int input_dim: Dimension of the input to pass to the
             feedforward linear layer.
         :param int output_dim: Dimension of the output from the
             residual layer.
         :param int hidden_dim: Hidden dimension for mapping the input
             (first block).
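For context, the hunk above only reformats the constructor signature (black's one-parameter-per-line style with a trailing comma); the block's body is not part of this diff. The following is a minimal sketch of what a residual block with these parameters plausibly computes. The class name, the skip projection, and the forward logic are assumptions; only the __init__ signature matches the diff.

# Minimal sketch only: the forward logic below is an assumption
# based on the parameter names, not part of this commit.
import torch
import torch.nn as nn


class ResidualBlockSketch(nn.Module):  # hypothetical name
    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim,
        spectral_norm=False,
        activation=torch.nn.ReLU(),
    ):
        super().__init__()
        linear = nn.Linear(input_dim, hidden_dim)
        out = nn.Linear(hidden_dim, output_dim)
        if spectral_norm:
            # torch.nn.utils.spectral_norm reparameterizes a module's
            # weight so its spectral norm is constrained to 1.
            linear = nn.utils.spectral_norm(linear)
            out = nn.utils.spectral_norm(out)
        self._block = nn.Sequential(linear, activation, out)
        # Assumption: project the input when its width differs from the
        # output width, so the skip connection adds elementwise.
        self._skip = (
            nn.Identity()
            if input_dim == output_dim
            else nn.Linear(input_dim, output_dim)
        )

    def forward(self, x):
        return self._skip(x) + self._block(x)


# Usage (hypothetical dims): a skip connection around a 2-layer MLP.
block = ResidualBlockSketch(64, 32, 128, spectral_norm=True)
out = block(torch.randn(8, 64))  # shape: (8, 32)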
@@ -82,6 +84,7 @@ class ResidualBlock(nn.Module):
 import torch
 import torch.nn as nn
 
+
 class EnhancedLinear(torch.nn.Module):
     """
     A wrapper class for enhancing a linear layer with activation and/or dropout.
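The single added line in this hunk is black enforcing PEP 8's rule of two blank lines around top-level definitions. A minimal before/after with hypothetical names:

# Before black: one blank line before a top-level class.
import torch

class Before(torch.nn.Module):  # hypothetical example
    pass


# After black: exactly two blank lines before a top-level class.
import torch


class After(torch.nn.Module):  # hypothetical example
    pass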
@@ -132,8 +135,9 @@ class EnhancedLinear(torch.nn.Module):
             self._model = torch.nn.Sequential(layer, self._drop(dropout))
 
         elif (dropout is not None) and (activation is not None):
-            self._model = torch.nn.Sequential(layer, activation,
-                                              self._drop(dropout))
+            self._model = torch.nn.Sequential(
+                layer, activation, self._drop(dropout)
+            )
 
     def forward(self, x):
         """
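For reference, a self-contained sketch of the wrapper this hunk touches. Only the (dropout is not None) and (activation is not None) branch and the forward stub appear in the diff; the other branches, the _drop helper, and the class name below are assumptions based on the docstring ("activation and/or dropout").

# Sketch under stated assumptions; the reformatted branch from the
# hunk above is marked inline.
import torch


class EnhancedLinearSketch(torch.nn.Module):  # hypothetical name
    def __init__(self, layer, activation=None, dropout=None):
        super().__init__()
        if (dropout is None) and (activation is None):
            self._model = torch.nn.Sequential(layer)
        elif (dropout is None) and (activation is not None):
            self._model = torch.nn.Sequential(layer, activation)
        elif (dropout is not None) and (activation is None):
            self._model = torch.nn.Sequential(layer, self._drop(dropout))
        elif (dropout is not None) and (activation is not None):
            # The branch shown in the hunk above, post-black.
            self._model = torch.nn.Sequential(
                layer, activation, self._drop(dropout)
            )

    def _drop(self, dropout):
        # Assumption: helper that builds a Dropout module from a rate.
        return torch.nn.Dropout(dropout)

    def forward(self, x):
        return self._model(x)


# Usage: a linear layer followed by ReLU and 10% dropout.
layer = EnhancedLinearSketch(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), dropout=0.1
)
y = layer(torch.randn(4, 16))  # shape: (4, 32)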