""" Solver module. """ from abc import ABCMeta, abstractmethod from ..model.network import Network import lightning.pytorch as pl from ..utils import check_consistency from ..problem import AbstractProblem import torch class SolverInterface(pl.LightningModule, metaclass=ABCMeta): """ Solver base class. """ def __init__(self, models, problem, optimizers, optimizers_kwargs, extra_features=None): """ :param models: A torch neural network model instance. :type models: torch.nn.Module :param problem: A problem definition instance. :type problem: AbstractProblem :param list(torch.nn.Module) extra_features: the additional input features to use as augmented input. If ``None`` no extra features are passed. If it is a list of ``torch.nn.Module``, the extra feature list is passed to all models. If it is a list of extra features' lists, each single list of extra feature is passed to a model. """ super().__init__() # check consistency of the inputs check_consistency(models, torch.nn.Module) check_consistency(problem, AbstractProblem) check_consistency(optimizers, torch.optim.Optimizer, subclass=True) check_consistency(optimizers_kwargs, dict) # put everything in a list if only one input if not isinstance(models, list): models = [models] if not isinstance(optimizers, list): optimizers = [optimizers] optimizers_kwargs = [optimizers_kwargs] # number of models and optimizers len_model = len(models) len_optimizer = len(optimizers) len_optimizer_kwargs = len(optimizers_kwargs) # check length consistency optimizers if len_model != len_optimizer: raise ValueError('You must define one optimizer for each model.' f'Got {len_model} models, and {len_optimizer}' ' optimizers.') # check length consistency optimizers kwargs if len_optimizer_kwargs != len_optimizer: raise ValueError('You must define one dictionary of keyword' ' arguments for each optimizers.' f'Got {len_optimizer} optimizers, and' f' {len_optimizer_kwargs} dicitionaries') # extra features handling if extra_features is None: extra_features = [None] * len_model else: # if we only have a list of extra features if not isinstance(extra_features[0], (tuple, list)): extra_features = [extra_features] * len_model else: # if we have a list of list extra features if len(extra_features) != len_model: raise ValueError('You passed a list of extrafeatures list with len' f'different of models len. Expected {len_model} ' f'got {len(extra_features)}. 
                             'If you want to use the same extra features '
                             'for all models, pass a single list of extra '
                             'features instead of a list of lists.')

        # assigning models and optimizers
        self._pina_models = []
        self._pina_optimizers = []
        for idx in range(len_model):
            model_ = Network(model=models[idx],
                             extra_features=extra_features[idx])
            optim_ = optimizers[idx](model_.parameters(),
                                     **optimizers_kwargs[idx])
            self._pina_models.append(model_)
            self._pina_optimizers.append(optim_)

        # assigning problem
        self._pina_problem = problem

    @abstractmethod
    def forward(self):
        """Forward pass of the solver's model(s)."""

    @abstractmethod
    def training_step(self):
        """A single training step; must return the loss to minimise."""

    @abstractmethod
    def configure_optimizers(self):
        """Return the optimizer(s) used for training."""

    @property
    def models(self):
        """The torch models."""
        return self._pina_models

    @property
    def optimizers(self):
        """The torch optimizers."""
        return self._pina_optimizers

    @property
    def problem(self):
        """The problem formulation."""
        return self._pina_problem
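

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API): a minimal
# concrete solver showing how a subclass is expected to fill in the abstract
# hooks above. The class name ``ExampleSolver``, the Adam optimizer with
# ``lr=1e-3``, the MSE loss, and the assumed ``(input, target)`` batch layout
# are assumptions for demonstration, not PINA's actual solvers.
class ExampleSolver(SolverInterface):
    """Toy solver: one model, one Adam optimizer, plain MSE loss."""

    def __init__(self, problem, model, extra_features=None):
        # a single model/optimizer pair; SolverInterface wraps both in lists
        super().__init__(models=model,
                         problem=problem,
                         optimizers=torch.optim.Adam,
                         optimizers_kwargs={'lr': 1e-3},
                         extra_features=extra_features)

    def forward(self, x):
        # delegate to the (single) wrapped Network
        return self.models[0](x)

    def training_step(self, batch, batch_idx):
        # assumed batch layout: (input points, target values)
        pts, target = batch
        loss = torch.nn.functional.mse_loss(self.forward(pts), target)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # return the optimizers built in SolverInterface.__init__
        return self.optimizers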