Network handles forward for all solvers
This commit is contained in:
committed by
Nicola Demo
parent
4844640727
commit
c90301c204
@@ -80,7 +80,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||
raise ValueError(
|
||||
'You passed a list of extrafeatures list with len'
|
||||
f'different of models len. Expected {len_model} '
|
||||
f'got {len(extra_features)}. If you want to use'
|
||||
f'got {len(extra_features)}. If you want to use '
|
||||
'the same list of extra features for all models, '
|
||||
'just pass a list of extrafeatures and not a list '
|
||||
'of list of extra features.')
|
||||
@@ -91,6 +91,8 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
for idx in range(len_model):
|
||||
model_ = Network(model=models[idx],
|
||||
input_variables=problem.input_variables,
|
||||
output_variables=problem.output_variables,
|
||||
extra_features=extra_features[idx])
|
||||
optim_ = optimizers[idx](model_.parameters(),
|
||||
**optimizers_kwargs[idx])
|
||||
|
||||
Reference in New Issue
Block a user