Network handles forward for all solvers

This commit is contained in:
Dario Coscia
2023-11-09 15:16:57 +01:00
committed by Nicola Demo
parent 4844640727
commit c90301c204
5 changed files with 63 additions and 67 deletions

View File

@@ -80,7 +80,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
raise ValueError(
'You passed a list of extrafeatures list with len'
f'different of models len. Expected {len_model} '
f'got {len(extra_features)}. If you want to use'
f'got {len(extra_features)}. If you want to use '
'the same list of extra features for all models, '
'just pass a list of extrafeatures and not a list '
'of list of extra features.')
@@ -91,6 +91,8 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
for idx in range(len_model):
model_ = Network(model=models[idx],
input_variables=problem.input_variables,
output_variables=problem.output_variables,
extra_features=extra_features[idx])
optim_ = optimizers[idx](model_.parameters(),
**optimizers_kwargs[idx])