use LabelTensor, fix minor, docs
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
"""Module for FeedForward model"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -5,22 +6,50 @@ from pina.label_tensor import LabelTensor
|
||||
|
||||
|
||||
class FeedForward(torch.nn.Module):
|
||||
"""
|
||||
The PINA implementation of feedforward network, also refered as multilayer
|
||||
perceptron.
|
||||
|
||||
:param list(str) input_variables: the list containing the labels
|
||||
corresponding to the input components of the model.
|
||||
:param list(str) output_variables: the list containing the labels
|
||||
corresponding to the components of the output computed by the model.
|
||||
:param int inner_size: number of neurons in the hidden layer(s). Default is
|
||||
20.
|
||||
:param int n_layers: number of hidden layers. Default is 2.
|
||||
:param func: the activation function to use. If a single
|
||||
:class:`torch.nn.Module` is passed, this is used as activation function
|
||||
after any layers, except the last one. If a list of Modules is passed,
|
||||
they are used as activation functions at any layers, in order.
|
||||
:param iterable(int) layers: a list containing the number of neurons for
|
||||
any hidden layers. If specified, the parameters `n_layers` e
|
||||
`inner_size` are not considered.
|
||||
:param iterable(torch.nn.Module) extra_features: the additional input
|
||||
features to use ad augmented input.
|
||||
"""
|
||||
def __init__(self, input_variables, output_variables, inner_size=20,
|
||||
n_layers=2, func=nn.Tanh, layers=None, extra_features=None):
|
||||
'''
|
||||
'''
|
||||
"""
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if extra_features is None:
|
||||
extra_features = []
|
||||
self.extra_features = nn.Sequential(*extra_features)
|
||||
|
||||
self.input_variables = input_variables
|
||||
self.input_dimension = len(input_variables)
|
||||
if isinstance(input_variables, int):
|
||||
self.input_variables = None
|
||||
self.input_dimension = input_variables
|
||||
elif isinstance(input_variables, (tuple, list)):
|
||||
self.input_variables = input_variables
|
||||
self.input_dimension = len(input_variables)
|
||||
|
||||
self.output_variables = output_variables
|
||||
self.output_dimension = len(output_variables)
|
||||
if isinstance(output_variables, int):
|
||||
self.output_variables = None
|
||||
self.output_dimension = output_variables
|
||||
elif isinstance(output_variables, (tuple, list)):
|
||||
self.output_variables = output_variables
|
||||
self.output_dimension = len(output_variables)
|
||||
|
||||
n_features = len(extra_features)
|
||||
|
||||
@@ -40,6 +69,9 @@ class FeedForward(torch.nn.Module):
|
||||
else:
|
||||
self.functions = [func for _ in range(len(self.layers)-1)]
|
||||
|
||||
if len(self.layers) != len(self.functions) + 1:
|
||||
raise RuntimeError('uncosistent number of layers and functions')
|
||||
|
||||
unique_list = []
|
||||
for layer, func in zip(self.layers[:-1], self.functions):
|
||||
unique_list.append(layer)
|
||||
@@ -51,18 +83,30 @@ class FeedForward(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Defines the computation performed at every call.
|
||||
|
||||
:param x: the input tensor.
|
||||
:type x: :class:`pina.LabelTensor`
|
||||
:return: the output computed by the model.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
if self.input_variables:
|
||||
x = x.extract(self.input_variables)
|
||||
|
||||
x = x[self.input_variables]
|
||||
nf = len(self.extra_features)
|
||||
if nf == 0:
|
||||
return LabelTensor(self.model(x.tensor), self.output_variables)
|
||||
labels = []
|
||||
features = []
|
||||
for i, feature in enumerate(self.extra_features):
|
||||
labels.append('k{}'.format(i))
|
||||
features.append(feature(x))
|
||||
|
||||
# if self.extra_features
|
||||
input_ = torch.zeros(x.shape[0], nf+x.shape[1], dtype=x.dtype,
|
||||
device=x.device)
|
||||
input_[:, :x.shape[1]] = x.tensor
|
||||
for i, feature in enumerate(self.extra_features,
|
||||
start=self.input_dimension):
|
||||
input_[:, i] = feature(x)
|
||||
return LabelTensor(self.model(input_), self.output_variables)
|
||||
if labels and features:
|
||||
features = torch.cat(features, dim=1)
|
||||
features_tensor = LabelTensor(features, labels)
|
||||
input_ = x.append(features_tensor) # TODO error when no LabelTens
|
||||
else:
|
||||
input_ = x
|
||||
|
||||
if self.output_variables:
|
||||
return LabelTensor(self.model(input_), self.output_variables)
|
||||
else:
|
||||
return self.model(input_)
|
||||
|
||||
Reference in New Issue
Block a user