Refactoring code
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
|
||||
from .problem import Problem
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
@@ -12,18 +10,18 @@ from pina.label_tensor import LabelTensor
|
||||
|
||||
class DeepFeedForward(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
inner_size=20,
|
||||
n_layers=2,
|
||||
func=nn.Tanh,
|
||||
input_variables=None,
|
||||
output_variables=None,
|
||||
layers=None,
|
||||
def __init__(self,
|
||||
inner_size=20,
|
||||
n_layers=2,
|
||||
func=nn.Tanh,
|
||||
input_variables=None,
|
||||
output_variables=None,
|
||||
layers=None,
|
||||
extra_features=None):
|
||||
'''
|
||||
'''
|
||||
super(DeepFeedForward, self).__init__()
|
||||
|
||||
|
||||
if extra_features is None:
|
||||
extra_features = []
|
||||
self.extra_features = nn.Sequential(*extra_features)
|
||||
@@ -48,7 +46,7 @@ class DeepFeedForward(torch.nn.Module):
|
||||
self.layers = []
|
||||
for i in range(len(tmp_layers)-1):
|
||||
self.layers.append(nn.Linear(tmp_layers[i], tmp_layers[i+1]))
|
||||
|
||||
|
||||
|
||||
|
||||
if isinstance(func, list):
|
||||
|
||||
Reference in New Issue
Block a user