* clean `condition` module
* add docs
This commit is contained in:
Nicola Demo
2023-04-18 15:00:26 +02:00
committed by GitHub
parent 736c78fd64
commit 2ca08b5236
18 changed files with 198 additions and 158 deletions

View File

@@ -6,16 +6,17 @@ from .feed_forward import FeedForward
class MultiFeedForward(torch.nn.Module):
"""
This model allows to create a network with multiple FeedForward combined
together. The user has to define the `forward` method choosing how to
combine the different FeedForward networks.
:param dict dff_dict: dictionary of FeedForward networks.
"""
def __init__(self, dff_dict):
'''
dff_dict: dict of FeedForward objects
'''
def __init__(self, ffn_dict):
super().__init__()
if not isinstance(dff_dict, dict):
if not isinstance(ffn_dict, dict):
raise TypeError
for name, constructor_args in dff_dict.items():
for name, constructor_args in ffn_dict.items():
setattr(self, name, FeedForward(**constructor_args))