* Adding Equations, solving typos * improve _code.rst * the team rst and restructure index.rst * fixing errors --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from ..utils import check_consistency
|
|
|
|
|
|
class Network(torch.nn.Module):
    """
    Wrapper around a :class:`torch.nn.Module` that optionally augments the
    input with the outputs of extra-feature modules before calling the model.
    """

    def __init__(self, model, extra_features=None):
        """
        Network class with standard forward method
        and possibility to pass extra features. This
        class is used internally in PINA to convert
        any :class:`torch.nn.Module` s in a PINA module.

        :param model: The torch model to convert in a PINA model.
        :type model: torch.nn.Module
        :param extra_features: List of torch models to augment the input,
            defaults to None.
        :type extra_features: list(torch.nn.Module)
        """
        super().__init__()

        # check model consistency
        check_consistency(model, nn.Module)
        self._model = model

        # check consistency and assign extra features
        if extra_features is None:
            # no augmentation requested: keep an empty, iterable placeholder
            self._extra_features = []
        else:
            for feat in extra_features:
                check_consistency(feat, nn.Module)
            # nn.Sequential is used here as an iterable, registered container
            # of the feature modules; it is never called as a pipeline.
            self._extra_features = nn.Sequential(*extra_features)

        # check model works with inputs
        # TODO

    def forward(self, x):
        """
        Forward method for Network class. This class
        implements the standard forward method, and
        it adds the possibility to pass extra features.
        All the PINA models ``forward`` s are overridden
        by this class, to enable :class:`pina.label_tensor.LabelTensor` labels
        extraction.

        Each extra feature is evaluated on the current input and its output
        is appended to that input, so later features see the columns added by
        earlier ones. Inputs exposing an ``append`` method use it; plain
        :class:`torch.Tensor` inputs, which have no ``append``, are
        concatenated with :func:`torch.cat` instead of raising
        ``AttributeError``.

        :param torch.Tensor x: Input of the network.
        :return torch.Tensor: Output of the network.
        """
        # extract features and append
        for feature in self._extra_features:
            extra = feature(x)
            append = getattr(x, "append", None)
            if callable(append):
                # input type provides its own append (label-aware tensors)
                x = append(extra)
            else:
                # NOTE(review): assumes features are stacked along the last
                # dimension for plain tensors -- confirm against the
                # LabelTensor.append convention.
                x = torch.cat((x, extra), dim=-1)
        # perform forward pass
        return self._model(x)

    @property
    def model(self):
        """
        The wrapped torch model passed at construction.
        """
        return self._model

    @property
    def extra_features(self):
        """
        The extra-feature modules: an empty list when none were given,
        otherwise a :class:`torch.nn.Sequential` holding them.
        """
        return self._extra_features
|