* Adding Equations, solving typos * improve _code.rst * the team rst and restructure index.rst * fixing errors --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
26 lines
729 B
Python
26 lines
729 B
Python
"""Module for Multi FeedForward model"""
|
|
import torch
|
|
|
|
from .feed_forward import FeedForward
|
|
|
|
|
|
class MultiFeedForward(torch.nn.Module):
    """
    The PINA implementation of MultiFeedForward network.

    This model allows to create a network with multiple FeedForward combined
    together. The user has to define the `forward` method choosing how to
    combine the different FeedForward networks.

    Each entry of ``ffn_dict`` is turned into a ``FeedForward`` instance and
    stored as the attribute ``self.<name>``, so a subclass can combine the
    sub-networks in its own ``forward``:

    .. code-block:: python

        class MyNet(MultiFeedForward):
            def forward(self, x):
                return self.ffn1(x) + self.ffn2(x)

        net = MyNet({"ffn1": {...}, "ffn2": {...}})

    :param dict ffn_dict: dictionary mapping each sub-network name (used as
        the attribute name) to the dictionary of keyword arguments passed to
        the ``FeedForward`` constructor.
    :raises TypeError: if ``ffn_dict`` is not a ``dict``.
    """

    def __init__(self, ffn_dict: dict):
        super().__init__()

        if not isinstance(ffn_dict, dict):
            raise TypeError(
                "ffn_dict must be a dict mapping names to FeedForward "
                f"keyword arguments, got {type(ffn_dict).__name__}."
            )

        # Use setattr (rather than a plain dict) so that, when FeedForward is
        # a torch.nn.Module, torch.nn.Module.__setattr__ registers each
        # network as a child module and its parameters are tracked.
        for name, constructor_args in ffn_dict.items():
            setattr(self, name, FeedForward(**constructor_args))
|