edited utils to take list (#115)
* enhanced difference domain * refactored utils * fixed typo * added tests --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com>
This commit is contained in:
@@ -4,20 +4,20 @@ from ..utils import check_consistency
|
||||
|
||||
|
||||
class Network(torch.nn.Module):
|
||||
|
||||
|
||||
def __init__(self, model, extra_features=None):
|
||||
super().__init__()
|
||||
|
||||
# check model consistency
|
||||
check_consistency(model, nn.Module, 'torch model')
|
||||
check_consistency(model, nn.Module)
|
||||
self._model = model
|
||||
|
||||
# check consistency and assign extra fatures
|
||||
# check consistency and assign extra fatures
|
||||
if extra_features is None:
|
||||
self._extra_features = []
|
||||
else:
|
||||
for feat in extra_features:
|
||||
check_consistency(feat, nn.Module, 'extra features')
|
||||
check_consistency(feat, nn.Module)
|
||||
self._extra_features = nn.Sequential(*extra_features)
|
||||
|
||||
# check model works with inputs
|
||||
@@ -44,4 +44,4 @@ class Network(torch.nn.Module):
|
||||
|
||||
@property
|
||||
def extra_features(self):
|
||||
return self._extra_features
|
||||
return self._extra_features
|
||||
|
||||
Reference in New Issue
Block a user