Refactoring code

This commit is contained in:
Your Name
2022-01-27 14:55:42 +01:00
parent fb16fc7f3a
commit fa8ffd5042
32 changed files with 417 additions and 442 deletions

View File

@@ -1,5 +1,3 @@
from .problem import Problem
import torch
import torch.nn as nn
import numpy as np
@@ -12,18 +10,18 @@ from pina.label_tensor import LabelTensor
class DeepFeedForward(torch.nn.Module):
def __init__(self,
inner_size=20,
n_layers=2,
func=nn.Tanh,
input_variables=None,
output_variables=None,
layers=None,
def __init__(self,
inner_size=20,
n_layers=2,
func=nn.Tanh,
input_variables=None,
output_variables=None,
layers=None,
extra_features=None):
'''
'''
super(DeepFeedForward, self).__init__()
if extra_features is None:
extra_features = []
self.extra_features = nn.Sequential(*extra_features)
@@ -48,7 +46,7 @@ class DeepFeedForward(torch.nn.Module):
self.layers = []
for i in range(len(tmp_layers)-1):
self.layers.append(nn.Linear(tmp_layers[i], tmp_layers[i+1]))
if isinstance(func, list):