edited utils to take list (#115)

* enhanced difference domain
* refactored utils
* fixed typo
* added tests
---------

Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
Kush
2023-06-19 18:47:52 +02:00
committed by Nicola Demo
parent aaf2bed732
commit 62ec69ccac
9 changed files with 73 additions and 47 deletions

View File

@@ -4,20 +4,20 @@ from ..utils import check_consistency
class Network(torch.nn.Module):
def __init__(self, model, extra_features=None):
super().__init__()
# check model consistency
check_consistency(model, nn.Module, 'torch model')
check_consistency(model, nn.Module)
self._model = model
# check consistency and assign extra fatures
# check consistency and assign extra fatures
if extra_features is None:
self._extra_features = []
else:
for feat in extra_features:
check_consistency(feat, nn.Module, 'extra features')
check_consistency(feat, nn.Module)
self._extra_features = nn.Sequential(*extra_features)
# check model works with inputs
@@ -44,4 +44,4 @@ class Network(torch.nn.Module):
@property
def extra_features(self):
return self._extra_features
return self._extra_features