PINA/pina/model/multi_feed_forward.py
Dario Coscia 8b7b61b3bd Documentation for v0.1 version (#199)
* Adding Equations, solving typos
* improve _code.rst
* the team rst and restuctore index.rst
* fixing errors

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
2023-11-17 09:51:29 +01:00

"""Module for Multi FeedForward model"""
import torch

from .feed_forward import FeedForward


class MultiFeedForward(torch.nn.Module):
    """
    The PINA implementation of MultiFeedForward network.

    This model allows the user to create a network composed of multiple
    FeedForward networks combined together. The user has to define the
    `forward` method, choosing how to combine the different FeedForward
    networks.

    :param dict ffn_dict: dictionary of FeedForward networks, mapping each
        attribute name to the keyword arguments passed to FeedForward.
    """

    def __init__(self, ffn_dict):
        super().__init__()
        if not isinstance(ffn_dict, dict):
            raise TypeError("ffn_dict must be a dict")
        # Each entry becomes a FeedForward submodule stored under its key.
        for name, constructor_args in ffn_dict.items():
            setattr(self, name, FeedForward(**constructor_args))
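The class leaves `forward` for the user to define. The sketch below illustrates that pattern and does not require PINA to be installed. It relies on three hypothetical pieces:

- `SimpleFeedForward` stands in for FeedForward. It is a plain Tanh MLP, and its keyword names (`input_dimensions`, `output_dimensions`, `inner_size`, `n_layers`) are chosen to resemble FeedForward's. Check them against FeedForward's actual signature before relying on them.
- `MultiFeedForwardSketch` copies the constructor above.
- `SumOfTwoNets` is an illustrative subclass with keys `net_a` and `net_b`. Its `forward` simply sums the two outputs.

```python
import torch


class SimpleFeedForward(torch.nn.Module):
    """Hypothetical stand-in for PINA's FeedForward (assumed: plain Tanh MLP)."""

    def __init__(self, input_dimensions, output_dimensions, inner_size=20, n_layers=2):
        super().__init__()
        layers = []
        in_size = input_dimensions
        for _ in range(n_layers):
            layers += [torch.nn.Linear(in_size, inner_size), torch.nn.Tanh()]
            in_size = inner_size
        layers.append(torch.nn.Linear(in_size, output_dimensions))
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class MultiFeedForwardSketch(torch.nn.Module):
    """Same constructor logic as MultiFeedForward, built on SimpleFeedForward."""

    def __init__(self, ffn_dict):
        super().__init__()
        if not isinstance(ffn_dict, dict):
            raise TypeError("ffn_dict must be a dict")
        for name, constructor_args in ffn_dict.items():
            # setattr on an nn.Module registers the network as a submodule
            setattr(self, name, SimpleFeedForward(**constructor_args))


class SumOfTwoNets(MultiFeedForwardSketch):
    """Hypothetical subclass: forward combines the networks by summation."""

    def forward(self, x):
        return self.net_a(x) + self.net_b(x)


model = SumOfTwoNets({
    "net_a": {"input_dimensions": 2, "output_dimensions": 1},
    "net_b": {"input_dimensions": 2, "output_dimensions": 1, "inner_size": 10},
})
out = model(torch.rand(5, 2))
print(out.shape)  # torch.Size([5, 1])
```

Because `setattr` on a `torch.nn.Module` registers each network as a submodule, the parameters of both networks appear in `model.parameters()` and are optimized together. Other combinations, such as products or concatenation, only change the body of `forward`.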