🎨 Format Python code with psf/black

This commit is contained in:
ndem0
2024-02-09 11:25:00 +00:00
committed by Nicola Demo
parent 591aeeb02b
commit cbb43a5392
64 changed files with 1323 additions and 955 deletions

View File

@@ -15,12 +15,14 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
LightningModule methods.
"""
def __init__(self,
models,
problem,
optimizers,
optimizers_kwargs,
extra_features=None):
def __init__(
self,
models,
problem,
optimizers,
optimizers_kwargs,
extra_features=None,
):
"""
:param models: A torch neural network model instance.
:type models: torch.nn.Module
@@ -30,7 +32,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
use.
:param list(dict) optimizer_kwargs: A list of optimizer constructor keyword args.
:param list(torch.nn.Module) extra_features: The additional input
features to use as augmented input. If ``None`` no extra features
features to use as augmented input. If ``None`` no extra features
are passed. If it is a list of :class:`torch.nn.Module`, the extra feature
list is passed to all models. If it is a list of extra features' lists,
each single list of extra feature is passed to a model.
@@ -57,19 +59,23 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
# check length consistency optimizers
if len_model != len_optimizer:
raise ValueError('You must define one optimizer for each model.'
f'Got {len_model} models, and {len_optimizer}'
' optimizers.')
raise ValueError(
"You must define one optimizer for each model."
f"Got {len_model} models, and {len_optimizer}"
" optimizers."
)
# check length consistency optimizers kwargs
if len_optimizer_kwargs != len_optimizer:
raise ValueError('You must define one dictionary of keyword'
' arguments for each optimizers.'
f'Got {len_optimizer} optimizers, and'
f' {len_optimizer_kwargs} dicitionaries')
raise ValueError(
"You must define one dictionary of keyword"
" arguments for each optimizers."
f"Got {len_optimizer} optimizers, and"
f" {len_optimizer_kwargs} dicitionaries"
)
# extra features handling
if (extra_features is None) or (len(extra_features)==0):
if (extra_features is None) or (len(extra_features) == 0):
extra_features = [None] * len_model
else:
# if we only have a list of extra features
@@ -78,24 +84,28 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
else: # if we have a list of list extra features
if len(extra_features) != len_model:
raise ValueError(
'You passed a list of extrafeatures list with len'
f'different of models len. Expected {len_model} '
f'got {len(extra_features)}. If you want to use '
'the same list of extra features for all models, '
'just pass a list of extrafeatures and not a list '
'of list of extra features.')
"You passed a list of extrafeatures list with len"
f"different of models len. Expected {len_model} "
f"got {len(extra_features)}. If you want to use "
"the same list of extra features for all models, "
"just pass a list of extrafeatures and not a list "
"of list of extra features."
)
# assigning model and optimizers
self._pina_models = []
self._pina_optimizers = []
for idx in range(len_model):
model_ = Network(model=models[idx],
input_variables=problem.input_variables,
output_variables=problem.output_variables,
extra_features=extra_features[idx])
optim_ = optimizers[idx](model_.parameters(),
**optimizers_kwargs[idx])
model_ = Network(
model=models[idx],
input_variables=problem.input_variables,
output_variables=problem.output_variables,
extra_features=extra_features[idx],
)
optim_ = optimizers[idx](
model_.parameters(), **optimizers_kwargs[idx]
)
self._pina_models.append(model_)
self._pina_optimizers.append(optim_)