Files
PINA/pina/model/network.py
Dario Coscia 8b7b61b3bd Documentation for v0.1 version (#199)
* Adding Equations, solving typos
* improve _code.rst
* the team rst and restuctore index.rst
* fixing errors

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
2023-11-17 09:51:29 +01:00

62 lines
1.9 KiB
Python

import torch
import torch.nn as nn
from ..utils import check_consistency
class Network(torch.nn.Module):
    """
    Wrapper around a :class:`torch.nn.Module` exposing a standard
    ``forward`` method, with the possibility of augmenting the input
    through extra feature modules before calling the wrapped model.
    """

    def __init__(self, model, extra_features=None):
        """
        Network class with standard forward method
        and possibility to pass extra features. This
        class is used internally in PINA to convert
        any :class:`torch.nn.Module` s in a PINA module.

        :param model: The torch model to convert in a PINA model.
        :type model: torch.nn.Module
        :param extra_features: Iterable of torch models to augment the
            input, defaults to None. Any iterable is accepted (it is
            materialised into a list once).
        :type extra_features: list(torch.nn.Module)
        """
        super().__init__()

        # check model consistency
        check_consistency(model, nn.Module)
        self._model = model

        # check consistency and assign extra features
        if extra_features is None:
            self._extra_features = []
        else:
            # Materialise first: the features are traversed twice (once to
            # validate, once to build the container), so a one-shot iterable
            # such as a generator would otherwise be exhausted by the
            # validation loop and silently produce an empty container.
            extra_features = list(extra_features)
            for feat in extra_features:
                check_consistency(feat, nn.Module)
            # nn.Sequential is used only as a registered container here; it
            # is iterated module by module in ``forward``, never called.
            self._extra_features = nn.Sequential(*extra_features)

        # check model works with inputs
        # TODO

    def forward(self, x):
        """
        Forward method for Network class. This class
        implements the standard forward method, and
        it adds the possibility to pass extra features.
        All the PINA models ``forward`` s are overridden
        by this class, to enable :class:`pina.label_tensor.LabelTensor` labels
        extraction.

        Each extra feature is evaluated on the current (already augmented)
        input and its output is appended to that input. Inputs exposing an
        ``append`` method use it; plain :class:`torch.Tensor` inputs, which
        have no such method, are concatenated along the last dimension.

        :param torch.Tensor x: Input of the network.
        :return torch.Tensor: Output of the network.
        """
        # extract features and append
        for feature in self._extra_features:
            extra = feature(x)
            append = getattr(x, "append", None)
            if callable(append):
                # Input type provides its own append; keep using it so its
                # behaviour is unchanged.
                # NOTE(review): presumably LabelTensor, per the docstring.
                x = append(extra)
            else:
                # torch.Tensor has no ``append``; fall back to concatenation.
                # NOTE(review): assumes features live on the last dimension.
                x = torch.cat((x, extra), dim=-1)

        # perform forward pass
        return self._model(x)

    @property
    def model(self):
        """The wrapped torch model."""
        return self._model

    @property
    def extra_features(self):
        """
        The extra feature modules: an empty list when none were given,
        otherwise a :class:`torch.nn.Sequential` holding them.
        """
        return self._extra_features