Documentation for v0.1 version (#199)

* Adding Equations, solving typos
* improve _code.rst
* the team rst and restuctore index.rst
* fixing errors

---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
Dario Coscia
2023-11-08 14:39:00 +01:00
committed by Nicola Demo
parent 3f9305d475
commit 8b7b61b3bd
144 changed files with 2741 additions and 1766 deletions

View File

@@ -2,30 +2,38 @@
from abc import ABCMeta, abstractmethod
from ..model.network import Network
import pytorch_lightning as pl
import pytorch_lightning
from ..utils import check_consistency
from ..problem import AbstractProblem
import torch
class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
""" Solver base class. """
class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
"""
Solver base class. This class inherits is a wrapper of
LightningModule class, inheriting all the
LightningModule methods.
"""
def __init__(self,
models,
problem,
optimizers,
optimizers_kwargs,
optimizers,
optimizers_kwargs,
extra_features=None):
"""
:param models: A torch neural network model instance.
:type models: torch.nn.Module
:param problem: A problem definition instance.
:type problem: AbstractProblem
:param list(torch.nn.Module) extra_features: the additional input
features to use as augmented input. If ``None`` no extra features
are passed. If it is a list of ``torch.nn.Module``, the extra feature
list is passed to all models. If it is a list of extra features' lists,
each single list of extra feature is passed to a model.
:param list(torch.optim.Optimizer) optimizer: A list of neural network optimizers to
use.
:param list(dict) optimizer_kwargs: A list of optimizer constructor keyword args.
:param list(torch.nn.Module) extra_features: The additional input
features to use as augmented input. If ``None`` no extra features
are passed. If it is a list of :class:`torch.nn.Module`, the extra feature
list is passed to all models. If it is a list of extra features' lists,
each single list of extra feature is passed to a model.
"""
super().__init__()
@@ -52,37 +60,40 @@ class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
raise ValueError('You must define one optimizer for each model.'
f'Got {len_model} models, and {len_optimizer}'
' optimizers.')
# check length consistency optimizers kwargs
if len_optimizer_kwargs != len_optimizer:
raise ValueError('You must define one dictionary of keyword'
' arguments for each optimizers.'
f'Got {len_optimizer} optimizers, and'
f' {len_optimizer_kwargs} dicitionaries')
# extra features handling
if extra_features is None:
if extra_features is None:
extra_features = [None] * len_model
else:
# if we only have a list of extra features
if not isinstance(extra_features[0], (tuple, list)):
extra_features = [extra_features] * len_model
else: # if we have a list of list extra features
else: # if we have a list of list extra features
if len(extra_features) != len_model:
raise ValueError('You passed a list of extrafeatures list with len'
f'different of models len. Expected {len_model} '
f'got {len(extra_features)}. If you want to use'
'the same list of extra features for all models, '
'just pass a list of extrafeatures and not a list '
'of list of extra features.')
raise ValueError(
'You passed a list of extrafeatures list with len'
f'different of models len. Expected {len_model} '
f'got {len(extra_features)}. If you want to use'
'the same list of extra features for all models, '
'just pass a list of extrafeatures and not a list '
'of list of extra features.')
# assigning model and optimizers
self._pina_models = []
self._pina_optimizers = []
for idx in range(len_model):
model_ = Network(model=models[idx], extra_features=extra_features[idx])
optim_ = optimizers[idx](model_.parameters(), **optimizers_kwargs[idx])
model_ = Network(model=models[idx],
extra_features=extra_features[idx])
optim_ = optimizers[idx](model_.parameters(),
**optimizers_kwargs[idx])
self._pina_models.append(model_)
self._pina_optimizers.append(optim_)
@@ -90,9 +101,9 @@ class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
self._pina_problem = problem
@abstractmethod
def forward(self):
def forward(self, *args, **kwargs):
pass
@abstractmethod
def training_step(self):
pass
@@ -131,4 +142,4 @@ class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
# """
# Set the problem formulation."""
# check_consistency(problem, AbstractProblem, 'pina problem')
# self._problem = problem
# self._problem = problem