Solvers for multiple models (#133)
* Solvers for multiple models - Implementing the possibility to add multiple models for solvers (e.g. GAN) - Implementing GAROM solver, see https://arxiv.org/abs/2305.15881 - Implementing tests for GAROM solver (cpu only) - Fixing docs PINNs - Creating a solver directory, for consistency in the package --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-040.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
6c8635c316
commit
701046661f
134
pina/solvers/solver.py
Normal file
134
pina/solvers/solver.py
Normal file
@@ -0,0 +1,134 @@
|
||||
""" Solver module. """
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from ..model.network import Network
|
||||
import lightning.pytorch as pl
|
||||
from ..utils import check_consistency
|
||||
from ..problem import AbstractProblem
|
||||
import torch
|
||||
|
||||
|
||||
class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
|
||||
""" Solver base class. """
|
||||
def __init__(self,
|
||||
models,
|
||||
problem,
|
||||
optimizers,
|
||||
optimizers_kwargs,
|
||||
extra_features=None):
|
||||
"""
|
||||
:param models: A torch neural network model instance.
|
||||
:type models: torch.nn.Module
|
||||
:param problem: A problem definition instance.
|
||||
:type problem: AbstractProblem
|
||||
:param list(torch.nn.Module) extra_features: the additional input
|
||||
features to use as augmented input. If ``None`` no extra features
|
||||
are passed. If it is a list of ``torch.nn.Module``, the extra feature
|
||||
list is passed to all models. If it is a list of extra features' lists,
|
||||
each single list of extra feature is passed to a model.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# check consistency of the inputs
|
||||
check_consistency(models, torch.nn.Module)
|
||||
check_consistency(problem, AbstractProblem)
|
||||
check_consistency(optimizers, torch.optim.Optimizer, subclass=True)
|
||||
check_consistency(optimizers_kwargs, dict)
|
||||
|
||||
# put everything in a list if only one input
|
||||
if not isinstance(models, list):
|
||||
models = [models]
|
||||
if not isinstance(optimizers, list):
|
||||
optimizers = [optimizers]
|
||||
optimizers_kwargs = [optimizers_kwargs]
|
||||
|
||||
# number of models and optimizers
|
||||
len_model = len(models)
|
||||
len_optimizer = len(optimizers)
|
||||
len_optimizer_kwargs = len(optimizers_kwargs)
|
||||
|
||||
# check length consistency optimizers
|
||||
if len_model != len_optimizer:
|
||||
raise ValueError('You must define one optimizer for each model.'
|
||||
f'Got {len_model} models, and {len_optimizer}'
|
||||
' optimizers.')
|
||||
|
||||
# check length consistency optimizers kwargs
|
||||
if len_optimizer_kwargs != len_optimizer:
|
||||
raise ValueError('You must define one dictionary of keyword'
|
||||
' arguments for each optimizers.'
|
||||
f'Got {len_optimizer} optimizers, and'
|
||||
f' {len_optimizer_kwargs} dicitionaries')
|
||||
|
||||
# extra features handling
|
||||
if extra_features is None:
|
||||
extra_features = [None] * len_model
|
||||
else:
|
||||
# if we only have a list of extra features
|
||||
if not isinstance(extra_features[0], (tuple, list)):
|
||||
extra_features = [extra_features] * len_model
|
||||
else: # if we have a list of list extra features
|
||||
if len(extra_features) != len_model:
|
||||
raise ValueError('You passed a list of extrafeatures list with len'
|
||||
f'different of models len. Expected {len_model} '
|
||||
f'got {len(extra_features)}. If you want to use'
|
||||
'the same list of extra features for all models, '
|
||||
'just pass a list of extrafeatures and not a list '
|
||||
'of list of extra features.')
|
||||
|
||||
# assigning model and optimizers
|
||||
self._pina_models = []
|
||||
self._pina_optimizers = []
|
||||
|
||||
for idx in range(len_model):
|
||||
model_ = Network(model=models[idx], extra_features=extra_features[idx])
|
||||
optim_ = optimizers[idx](model_.parameters(), **optimizers_kwargs[idx])
|
||||
self._pina_models.append(model_)
|
||||
self._pina_optimizers.append(optim_)
|
||||
|
||||
# assigning problem
|
||||
self._pina_problem = problem
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def configure_optimizers(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
"""
|
||||
The torch model."""
|
||||
return self._pina_models
|
||||
|
||||
@property
|
||||
def optimizers(self):
|
||||
"""
|
||||
The torch model."""
|
||||
return self._pina_optimizers
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
"""
|
||||
The problem formulation."""
|
||||
return self._pina_problem
|
||||
|
||||
# @model.setter
|
||||
# def model(self, new_model):
|
||||
# """
|
||||
# Set the torch."""
|
||||
# check_consistency(new_model, nn.Module, 'torch model')
|
||||
# self._model= new_model
|
||||
|
||||
# @problem.setter
|
||||
# def problem(self, problem):
|
||||
# """
|
||||
# Set the problem formulation."""
|
||||
# check_consistency(problem, AbstractProblem, 'pina problem')
|
||||
# self._problem = problem
|
||||
Reference in New Issue
Block a user