288 lines
9.1 KiB
Python
288 lines
9.1 KiB
Python
""" Solver module. """
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
from ..model.network import Network
|
|
import pytorch_lightning
|
|
from ..utils import check_consistency
|
|
from ..problem import AbstractProblem
|
|
from ..optim import Optimizer, Scheduler
|
|
import torch
|
|
import sys
|
|
|
|
|
|
# class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
|
# """
|
|
# Solver base class. This class is a wrapper of the
|
|
# LightningModule class, inheriting all the
|
|
# LightningModule methods.
|
|
# """
|
|
|
|
# def __init__(
|
|
# self,
|
|
# models,
|
|
# problem,
|
|
# optimizers,
|
|
# optimizers_kwargs,
|
|
# extra_features=None,
|
|
# ):
|
|
# """
|
|
# :param models: A torch neural network model instance.
|
|
# :type models: torch.nn.Module
|
|
# :param problem: A problem definition instance.
|
|
# :type problem: AbstractProblem
|
|
# :param list(torch.optim.Optimizer) optimizer: A list of neural network optimizers to
|
|
# use.
|
|
# :param list(dict) optimizer_kwargs: A list of optimizer constructor keyword args.
|
|
# :param list(torch.nn.Module) extra_features: The additional input
|
|
# features to use as augmented input. If ``None`` no extra features
|
|
# are passed. If it is a list of :class:`torch.nn.Module`, the extra feature
|
|
# list is passed to all models. If it is a list of extra features' lists,
|
|
# each single list of extra feature is passed to a model.
|
|
# """
|
|
# super().__init__()
|
|
|
|
# # check consistency of the inputs
|
|
# check_consistency(models, torch.nn.Module)
|
|
# check_consistency(problem, AbstractProblem)
|
|
# check_consistency(optimizers, torch.optim.Optimizer, subclass=True)
|
|
# check_consistency(optimizers_kwargs, dict)
|
|
|
|
# # put everything in a list if only one input
|
|
# if not isinstance(models, list):
|
|
# models = [models]
|
|
# if not isinstance(optimizers, list):
|
|
# optimizers = [optimizers]
|
|
# optimizers_kwargs = [optimizers_kwargs]
|
|
|
|
# # number of models and optimizers
|
|
# len_model = len(models)
|
|
# len_optimizer = len(optimizers)
|
|
# len_optimizer_kwargs = len(optimizers_kwargs)
|
|
|
|
# # check length consistency optimizers
|
|
# if len_model != len_optimizer:
|
|
# raise ValueError(
|
|
# "You must define one optimizer for each model."
|
|
# f"Got {len_model} models, and {len_optimizer}"
|
|
# " optimizers."
|
|
# )
|
|
|
|
# # check length consistency optimizers kwargs
|
|
# if len_optimizer_kwargs != len_optimizer:
|
|
# raise ValueError(
|
|
# "You must define one dictionary of keyword"
|
|
# " arguments for each optimizers."
|
|
# f"Got {len_optimizer} optimizers, and"
|
|
# f" {len_optimizer_kwargs} dicitionaries"
|
|
# )
|
|
|
|
# # extra features handling
|
|
# if (extra_features is None) or (len(extra_features) == 0):
|
|
# extra_features = [None] * len_model
|
|
# else:
|
|
# # if we only have a list of extra features
|
|
# if not isinstance(extra_features[0], (tuple, list)):
|
|
# extra_features = [extra_features] * len_model
|
|
# else: # if we have a list of list extra features
|
|
# if len(extra_features) != len_model:
|
|
# raise ValueError(
|
|
# "You passed a list of extrafeatures list with len"
|
|
# f"different of models len. Expected {len_model} "
|
|
# f"got {len(extra_features)}. If you want to use "
|
|
# "the same list of extra features for all models, "
|
|
# "just pass a list of extrafeatures and not a list "
|
|
# "of list of extra features."
|
|
# )
|
|
|
|
# # assigning model and optimizers
|
|
# self._pina_models = []
|
|
# self._pina_optimizers = []
|
|
|
|
# for idx in range(len_model):
|
|
# model_ = Network(
|
|
# model=models[idx],
|
|
# input_variables=problem.input_variables,
|
|
# output_variables=problem.output_variables,
|
|
# extra_features=extra_features[idx],
|
|
# )
|
|
# optim_ = optimizers[idx](
|
|
# model_.parameters(), **optimizers_kwargs[idx]
|
|
# )
|
|
# self._pina_models.append(model_)
|
|
# self._pina_optimizers.append(optim_)
|
|
|
|
# # assigning problem
|
|
# self._pina_problem = problem
|
|
|
|
# @abstractmethod
|
|
# def forward(self, *args, **kwargs):
|
|
# pass
|
|
|
|
# @abstractmethod
|
|
# def training_step(self):
|
|
# pass
|
|
|
|
# @abstractmethod
|
|
# def configure_optimizers(self):
|
|
# pass
|
|
|
|
# @property
|
|
# def models(self):
|
|
# """
|
|
# The torch model."""
|
|
# return self._pina_models
|
|
|
|
# @property
|
|
# def optimizers(self):
|
|
# """
|
|
# The torch model."""
|
|
# return self._pina_optimizers
|
|
|
|
# @property
|
|
# def problem(self):
|
|
# """
|
|
# The problem formulation."""
|
|
# return self._pina_problem
|
|
|
|
# def on_train_start(self):
|
|
# """
|
|
# On training start this function is called to do global checks for
|
|
# the different solvers.
|
|
# """
|
|
|
|
# # 1. Check the version for dataloader
|
|
# dataloader = self.trainer.train_dataloader
|
|
# if sys.version_info < (3, 8):
|
|
# dataloader = dataloader.loaders
|
|
# self._dataloader = dataloader
|
|
|
|
# return super().on_train_start()
|
|
|
|
# @model.setter
|
|
# def model(self, new_model):
|
|
# """
|
|
# Set the torch."""
|
|
# check_consistency(new_model, nn.Module, 'torch model')
|
|
# self._model= new_model
|
|
|
|
# @problem.setter
|
|
# def problem(self, problem):
|
|
# """
|
|
# Set the problem formulation."""
|
|
# check_consistency(problem, AbstractProblem, 'pina problem')
|
|
# self._problem = problem
|
|
|
|
class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
    """
    Solver base class. This class is a wrapper of the LightningModule
    class, inheriting all the LightningModule methods.

    Concrete solvers must implement :meth:`forward`, :meth:`training_step`
    and :meth:`configure_optimizers`.
    """

    def __init__(
        self,
        model,
        problem,
        optimizer,
        scheduler,
    ):
        """
        :param model: A torch neural network model instance, or a list of
            them. A single instance is wrapped in a one-element list.
        :type model: torch.nn.Module | list(torch.nn.Module)
        :param problem: A problem definition instance.
        :type problem: AbstractProblem
        :param optimizer: The optimizer to use, or a list with one
            optimizer per model. A single instance is wrapped in a
            one-element list.
        :type optimizer: Optimizer | list(Optimizer)
        :param scheduler: The learning rate scheduler to use, or a list of
            them. A single instance is wrapped in a one-element list.
        :type scheduler: Scheduler | list(Scheduler)
        :raises ValueError: If the number of models differs from the number
            of optimizers.
        """
        super().__init__()

        # check consistency of the inputs (done before the list wrapping
        # below, so the raw user inputs are what gets validated)
        check_consistency(model, torch.nn.Module)
        check_consistency(problem, AbstractProblem)
        check_consistency(optimizer, Optimizer)
        check_consistency(scheduler, Scheduler)

        # put everything in a list if only one input
        if not isinstance(model, list):
            model = [model]
        if not isinstance(scheduler, list):
            scheduler = [scheduler]
        if not isinstance(optimizer, list):
            optimizer = [optimizer]

        # number of models and optimizers
        len_model = len(model)
        len_optimizer = len(optimizer)

        # check length consistency optimizers
        if len_model != len_optimizer:
            raise ValueError(
                "You must define one optimizer for each model. "
                f"Got {len_model} models, and {len_optimizer}"
                " optimizers."
            )

        # assigning problem, models, optimizers and schedulers
        # NOTE(review): the models are kept in a plain Python list, which
        # torch.nn.Module does not register as submodules; presumably the
        # concrete solvers register the networks themselves - confirm.
        self._pina_problem = problem
        self._pina_model = model
        self._pina_optimizer = optimizer
        self._pina_scheduler = scheduler

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Forward pass of the solver; must be implemented by subclasses."""

    @abstractmethod
    def training_step(self):
        """Training step of the solver; must be implemented by subclasses."""

    @abstractmethod
    def configure_optimizers(self):
        """Optimizer configuration; must be implemented by subclasses."""

    @property
    def models(self):
        """
        The list of torch models."""
        return self._pina_model

    @property
    def optimizers(self):
        """
        The list of optimizers."""
        # NOTE(review): this property shadows the ``optimizers()`` method
        # that LightningModule exposes - confirm this is intentional.
        return self._pina_optimizer

    @property
    def schedulers(self):
        """
        The list of schedulers."""
        return self._pina_scheduler

    @property
    def problem(self):
        """
        The problem formulation."""
        return self._pina_problem

    def on_train_start(self):
        """
        Hook called when training starts. It stores the trainer's training
        dataloader on ``self._dataloader`` and then delegates to the parent
        class hook.
        """

        # 1. Check the version for dataloader
        # On Python < 3.8 the dataloader is read from the ``loaders``
        # attribute instead - presumably because the Lightning releases
        # available there wrap the user dataloader; TODO confirm.
        dataloader = self.trainer.train_dataloader
        if sys.version_info < (3, 8):
            dataloader = dataloader.loaders
        self._dataloader = dataloader

        return super().on_train_start()
|
|
|
|
# @model.setter
|
|
# def model(self, new_model):
|
|
# """
|
|
# Set the torch."""
|
|
# check_consistency(new_model, nn.Module, 'torch model')
|
|
# self._model= new_model
|
|
|
|
# @problem.setter
|
|
# def problem(self, problem):
|
|
# """
|
|
# Set the problem formulation."""
|
|
# check_consistency(problem, AbstractProblem, 'pina problem')
|
|
# self._problem = problem
|