Implement Dataset, Dataloader and DataModule classes and fix SupervisedSolver

This commit is contained in:
FilippoOlivo
2024-10-16 11:24:37 +02:00
committed by Nicola Demo
parent b9753c34b2
commit c9304fb9bb
30 changed files with 770 additions and 784 deletions
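The commit title references new Dataset, Dataloader, and DataModule classes that do not appear in the hunks below. For orientation only, here is a minimal sketch of the Lightning-style pattern such classes typically follow; all names are illustrative stand-ins, not the actual PINA API:

```python
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning


class SampleDataset(Dataset):
    """Pairs of input and target tensors for supervised training."""

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]


class SampleDataModule(pytorch_lightning.LightningDataModule):
    """Builds the training dataloader from the dataset, Lightning-style."""

    def __init__(self, inputs, targets, batch_size=32):
        super().__init__()
        self.dataset = SampleDataset(inputs, targets)
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)


# usage: dm = SampleDataModule(torch.rand(100, 2), torch.rand(100, 1))
```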


@@ -10,168 +10,6 @@ import torch
import sys
# class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
# """
# Solver base class. This class is a wrapper of the
# LightningModule class, inheriting all the
# LightningModule methods.
# """
# def __init__(
# self,
# models,
# problem,
# optimizers,
# optimizers_kwargs,
# extra_features=None,
# ):
# """
# :param models: A torch neural network model instance.
# :type models: torch.nn.Module
# :param problem: A problem definition instance.
# :type problem: AbstractProblem
# :param list(torch.optim.Optimizer) optimizers: A list of neural network
# optimizers to use.
# :param list(dict) optimizers_kwargs: A list of optimizer constructor keyword args.
# :param list(torch.nn.Module) extra_features: The additional input
# features to use as augmented input. If ``None`` no extra features
# are passed. If it is a list of :class:`torch.nn.Module`, the extra feature
# list is passed to all models. If it is a list of lists of extra features,
# each list of extra features is passed to the corresponding model.
# """
# super().__init__()
# # check consistency of the inputs
# check_consistency(models, torch.nn.Module)
# check_consistency(problem, AbstractProblem)
# check_consistency(optimizers, torch.optim.Optimizer, subclass=True)
# check_consistency(optimizers_kwargs, dict)
# # put everything in a list if only one input
# if not isinstance(models, list):
# models = [models]
# if not isinstance(optimizers, list):
# optimizers = [optimizers]
# optimizers_kwargs = [optimizers_kwargs]
# # number of models and optimizers
# len_model = len(models)
# len_optimizer = len(optimizers)
# len_optimizer_kwargs = len(optimizers_kwargs)
# # check length consistency optimizers
# if len_model != len_optimizer:
# raise ValueError(
# "You must define one optimizer for each model."
# f"Got {len_model} models, and {len_optimizer}"
# " optimizers."
# )
# # check length consistency optimizers kwargs
# if len_optimizer_kwargs != len_optimizer:
# raise ValueError(
# "You must define one dictionary of keyword"
# " arguments for each optimizers."
# f"Got {len_optimizer} optimizers, and"
# f" {len_optimizer_kwargs} dicitionaries"
# )
# # extra features handling
# if (extra_features is None) or (len(extra_features) == 0):
# extra_features = [None] * len_model
# else:
# # if we only have a list of extra features
# if not isinstance(extra_features[0], (tuple, list)):
# extra_features = [extra_features] * len_model
# else: # if we have a list of list extra features
# if len(extra_features) != len_model:
# raise ValueError(
# "You passed a list of extrafeatures list with len"
# f"different of models len. Expected {len_model} "
# f"got {len(extra_features)}. If you want to use "
# "the same list of extra features for all models, "
# "just pass a list of extrafeatures and not a list "
# "of list of extra features."
# )
# # assigning model and optimizers
# self._pina_models = []
# self._pina_optimizers = []
# for idx in range(len_model):
# model_ = Network(
# model=models[idx],
# input_variables=problem.input_variables,
# output_variables=problem.output_variables,
# extra_features=extra_features[idx],
# )
# optim_ = optimizers[idx](
# model_.parameters(), **optimizers_kwargs[idx]
# )
# self._pina_models.append(model_)
# self._pina_optimizers.append(optim_)
# # assigning problem
# self._pina_problem = problem
# @abstractmethod
# def forward(self, *args, **kwargs):
# pass
# @abstractmethod
# def training_step(self):
# pass
# @abstractmethod
# def configure_optimizers(self):
# pass
# @property
# def models(self):
# """
# The torch models."""
# return self._pina_models
# @property
# def optimizers(self):
# """
# The torch optimizers."""
# return self._pina_optimizers
# @property
# def problem(self):
# """
# The problem formulation."""
# return self._pina_problem
# def on_train_start(self):
# """
# On training start this function is called to perform global checks for
# the different solvers.
# """
# # 1. Check the version for the dataloader
# dataloader = self.trainer.train_dataloader
# if sys.version_info < (3, 8):
# dataloader = dataloader.loaders
# self._dataloader = dataloader
# return super().on_train_start()
# @model.setter
# def model(self, new_model):
# """
# Set the torch model."""
# check_consistency(new_model, nn.Module, 'torch model')
# self._model= new_model
# @problem.setter
# def problem(self, problem):
# """
# Set the problem formulation."""
# check_consistency(problem, AbstractProblem, 'pina problem')
# self._problem = problem
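The trickiest logic in the removed block above is the extra-features broadcasting: a flat list is shared across all models, while a list of lists must provide one entry per model. A standalone sketch of that rule (the helper name is hypothetical, for illustration only):

```python
def broadcast_extra_features(extra_features, n_models):
    """Replicate the removed broadcasting rule for extra features."""
    if extra_features is None or len(extra_features) == 0:
        # no extra features: one None per model
        return [None] * n_models
    if not isinstance(extra_features[0], (tuple, list)):
        # flat list: share the same features across all models
        return [extra_features] * n_models
    if len(extra_features) != n_models:
        # list of lists: must match the number of models
        raise ValueError(
            f"Expected {n_models} extra-feature lists, got {len(extra_features)}."
        )
    return extra_features
```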
class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
"""
Solver base class. This class is a wrapper of the
@@ -181,10 +19,12 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
def __init__(
self,
model,
models,
problem,
optimizer,
scheduler,
optimizers,
schedulers,
extra_features,
use_lt=True
):
"""
:param model: A torch neural network model instance.
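The signature change above means the solver now accepts either a single object or a list for models, optimizers, and schedulers. A hedged, minimal stand-in showing only the argument shape (TinySolver is not the real class; the real __init__ also takes a problem, extra_features, and PINA's own Optimizer/Scheduler wrappers):

```python
import torch


class TinySolver:
    """Stand-in that mimics only the new argument shape of SolverInterface."""

    def __init__(self, models, optimizers, schedulers, use_lt=True):
        # single objects are encapsulated in one-element lists
        self.models = models if isinstance(models, list) else [models]
        self.optimizers = optimizers if isinstance(optimizers, list) else [optimizers]
        self.schedulers = schedulers if isinstance(schedulers, list) else [schedulers]
        self.use_lt = use_lt


net = torch.nn.Linear(2, 1)
opt = torch.optim.Adam(net.parameters())
sched = torch.optim.lr_scheduler.ConstantLR(opt)
solver = TinySolver(models=net, optimizers=opt, schedulers=sched)
assert len(solver.models) == len(solver.optimizers) == 1
```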
@@ -197,22 +37,45 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
super().__init__()
# check consistency of the inputs
check_consistency(model, torch.nn.Module)
check_consistency(problem, AbstractProblem)
check_consistency(optimizer, Optimizer)
check_consistency(scheduler, Scheduler)
self._check_solver_consistency(problem)
# put everything in a list if only one input
if not isinstance(model, list):
model = [model]
if not isinstance(scheduler, list):
scheduler = [scheduler]
if not isinstance(optimizer, list):
optimizer = [optimizer]
# Check consistency of the models argument and encapsulate it in a list
if not isinstance(models, list):
check_consistency(models, torch.nn.Module)
# put everything in a list if only one input
models = [models]
else:
for idx in range(len(models)):
# Check consistency
check_consistency(models[idx], torch.nn.Module)
len_model = len(models)
# number of models and optimizers
len_model = len(model)
len_optimizer = len(optimizer)
# If use_lt is True, wrap each model in a Network to extract labeled inputs
if use_lt is True:
for idx in range(len(models)):
models[idx] = Network(
model=models[idx],
input_variables=problem.input_variables,
output_variables=problem.output_variables,
extra_features=extra_features,
)
# Check scheduler consistency and encapsulate in a list
if not isinstance(schedulers, list):
check_consistency(schedulers, Scheduler)
schedulers = [schedulers]
else:
for scheduler in schedulers:
check_consistency(scheduler, Scheduler)
# Check optimizer consistency and encapsulate in a list
if not isinstance(optimizers, list):
check_consistency(optimizers, Optimizer)
optimizers = [optimizers]
else:
for optimizer in optimizers:
check_consistency(optimizer, Optimizer)
len_optimizer = len(optimizers)
# check length consistency optimizers
if len_model != len_optimizer:
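After encapsulation, the check at the end of this hunk enforces a one-to-one pairing between models and optimizers. The same invariant as a standalone sketch (the helper name is hypothetical, mirroring the diff's error message):

```python
def pair_models_with_optimizers(models, optimizers):
    """Enforce the one-optimizer-per-model invariant from the diff."""
    if len(models) != len(optimizers):
        raise ValueError(
            "You must define one optimizer for each model. "
            f"Got {len(models)} models, and {len(optimizers)} optimizers."
        )
    # pair each model with its optimizer
    return list(zip(models, optimizers))
```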
@@ -223,10 +86,12 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
)
# extra features handling
self._pina_models = models
self._pina_optimizers = optimizers
self._pina_schedulers = schedulers
self._pina_problem = problem
self._pina_model = model
self._pina_optimizer = optimizer
self._pina_scheduler = scheduler
@abstractmethod
def forward(self, *args, **kwargs):
@@ -244,13 +109,13 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
def models(self):
"""
The torch models."""
return self._pina_model
return self._pina_models
@property
def optimizers(self):
"""
The torch optimizers."""
return self._pina_optimizer
return self._pina_optimizers
@property
def problem(self):
@@ -272,16 +137,10 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
return super().on_train_start()
# @model.setter
# def model(self, new_model):
# """
# Set the torch model."""
# check_consistency(new_model, nn.Module, 'torch model')
# self._model= new_model
# @problem.setter
# def problem(self, problem):
# """
# Set the problem formulation."""
# check_consistency(problem, AbstractProblem, 'pina problem')
# self._problem = problem
def _check_solver_consistency(self, problem):
"""
Check that the problem's conditions are compatible with this solver.
"""
for _, condition in problem.conditions.items():
if not set(self.accepted_condition_types).issubset(condition.condition_type):
raise ValueError(f'{type(self).__name__} does not support condition type {condition.condition_type}')
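The check above requires every type the solver accepts to appear among a condition's types. A runnable sketch of the same logic with plain lists (the real code reads accepted_condition_types from the solver subclass and condition_type from Condition objects):

```python
def check_conditions(accepted_condition_types, conditions):
    """Mirror _check_solver_consistency with plain dicts and lists."""
    for name, condition_types in conditions.items():
        # every accepted type must be provided by the condition
        if not set(accepted_condition_types).issubset(condition_types):
            raise ValueError(
                f"Solver does not support condition '{name}' "
                f"with types {condition_types}"
            )


check_conditions(["supervised"], {"data": ["supervised", "physics"]})  # passes
# check_conditions(["physics"], {"data": ["supervised"]})              # raises
```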