Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver
Committed by Nicola Demo
parent b9753c34b2
commit c9304fb9bb
@@ -27,13 +27,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
     """

     def __init__(
-        self,
-        models,
-        problem,
-        optimizers,
-        optimizers_kwargs,
-        extra_features,
-        loss,
+        self,
+        models,
+        problem,
+        optimizers,
+        optimizers_kwargs,
+        extra_features,
+        loss,
     ):
         """
         :param models: Multiple torch neural network model instances.
@@ -178,7 +178,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         try:
             residual = equation.residual(samples, self.forward(samples))
         except (
-            TypeError
+            TypeError
         ):  # this occurs when the function has three inputs, i.e. inverse problem
             residual = equation.residual(
                 samples, self.forward(samples), self._params
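The try/except above dispatches on the residual's signature: a forward problem's residual takes (samples, output), while an inverse problem's residual also receives the unknown parameters. A minimal sketch of the same dispatch, with a hypothetical equation object standing in for PINA's Equation:

    def compute_residual(equation, samples, output, params=None):
        # Try the two-argument residual first; a TypeError signals that the
        # three-argument (inverse problem) variant is required instead.
        try:
            return equation.residual(samples, output)
        except TypeError:
            return equation.residual(samples, output, params)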
@@ -10,168 +10,6 @@ import torch
 import sys


-# class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
-#     """
-#     Solver base class. This class inherits is a wrapper of
-#     LightningModule class, inheriting all the
-#     LightningModule methods.
-#     """

-#     def __init__(
-#         self,
-#         models,
-#         problem,
-#         optimizers,
-#         optimizers_kwargs,
-#         extra_features=None,
-#     ):
-#         """
-#         :param models: A torch neural network model instance.
-#         :type models: torch.nn.Module
-#         :param problem: A problem definition instance.
-#         :type problem: AbstractProblem
-#         :param list(torch.optim.Optimizer) optimizer: A list of neural network optimizers to
-#             use.
-#         :param list(dict) optimizer_kwargs: A list of optimizer constructor keyword args.
-#         :param list(torch.nn.Module) extra_features: The additional input
-#             features to use as augmented input. If ``None`` no extra features
-#             are passed. If it is a list of :class:`torch.nn.Module`, the extra feature
-#             list is passed to all models. If it is a list of extra features' lists,
-#             each single list of extra feature is passed to a model.
-#         """
-#         super().__init__()

-#         # check consistency of the inputs
-#         check_consistency(models, torch.nn.Module)
-#         check_consistency(problem, AbstractProblem)
-#         check_consistency(optimizers, torch.optim.Optimizer, subclass=True)
-#         check_consistency(optimizers_kwargs, dict)

-#         # put everything in a list if only one input
-#         if not isinstance(models, list):
-#             models = [models]
-#         if not isinstance(optimizers, list):
-#             optimizers = [optimizers]
-#             optimizers_kwargs = [optimizers_kwargs]

-#         # number of models and optimizers
-#         len_model = len(models)
-#         len_optimizer = len(optimizers)
-#         len_optimizer_kwargs = len(optimizers_kwargs)

-#         # check length consistency optimizers
-#         if len_model != len_optimizer:
-#             raise ValueError(
-#                 "You must define one optimizer for each model."
-#                 f"Got {len_model} models, and {len_optimizer}"
-#                 " optimizers."
-#             )

-#         # check length consistency optimizers kwargs
-#         if len_optimizer_kwargs != len_optimizer:
-#             raise ValueError(
-#                 "You must define one dictionary of keyword"
-#                 " arguments for each optimizers."
-#                 f"Got {len_optimizer} optimizers, and"
-#                 f" {len_optimizer_kwargs} dicitionaries"
-#             )

-#         # extra features handling
-#         if (extra_features is None) or (len(extra_features) == 0):
-#             extra_features = [None] * len_model
-#         else:
-#             # if we only have a list of extra features
-#             if not isinstance(extra_features[0], (tuple, list)):
-#                 extra_features = [extra_features] * len_model
-#             else:  # if we have a list of list extra features
-#                 if len(extra_features) != len_model:
-#                     raise ValueError(
-#                         "You passed a list of extrafeatures list with len"
-#                         f"different of models len. Expected {len_model} "
-#                         f"got {len(extra_features)}. If you want to use "
-#                         "the same list of extra features for all models, "
-#                         "just pass a list of extrafeatures and not a list "
-#                         "of list of extra features."
-#                     )

-#         # assigning model and optimizers
-#         self._pina_models = []
-#         self._pina_optimizers = []

-#         for idx in range(len_model):
-#             model_ = Network(
-#                 model=models[idx],
-#                 input_variables=problem.input_variables,
-#                 output_variables=problem.output_variables,
-#                 extra_features=extra_features[idx],
-#             )
-#             optim_ = optimizers[idx](
-#                 model_.parameters(), **optimizers_kwargs[idx]
-#             )
-#             self._pina_models.append(model_)
-#             self._pina_optimizers.append(optim_)

-#         # assigning problem
-#         self._pina_problem = problem

-#     @abstractmethod
-#     def forward(self, *args, **kwargs):
-#         pass

-#     @abstractmethod
-#     def training_step(self):
-#         pass

-#     @abstractmethod
-#     def configure_optimizers(self):
-#         pass

-#     @property
-#     def models(self):
-#         """
-#         The torch model."""
-#         return self._pina_models

-#     @property
-#     def optimizers(self):
-#         """
-#         The torch model."""
-#         return self._pina_optimizers

-#     @property
-#     def problem(self):
-#         """
-#         The problem formulation."""
-#         return self._pina_problem

-#     def on_train_start(self):
-#         """
-#         On training epoch start this function is call to do global checks for
-#         the different solvers.
-#         """

-#         # 1. Check the verison for dataloader
-#         dataloader = self.trainer.train_dataloader
-#         if sys.version_info < (3, 8):
-#             dataloader = dataloader.loaders
-#         self._dataloader = dataloader

-#         return super().on_train_start()

-#     # @model.setter
-#     # def model(self, new_model):
-#     #     """
-#     #     Set the torch."""
-#     #     check_consistency(new_model, nn.Module, 'torch model')
-#     #     self._model= new_model

-#     # @problem.setter
-#     # def problem(self, problem):
-#     #     """
-#     #     Set the problem formulation."""
-#     #     check_consistency(problem, AbstractProblem, 'pina problem')
-#     #     self._problem = problem
 class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
     """
     Solver base class. This class is a wrapper of
@@ -181,10 +19,12 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
     def __init__(
         self,
-        model,
+        models,
         problem,
-        optimizer,
-        scheduler,
+        optimizers,
+        schedulers,
         extra_features,
+        use_lt=True
     ):
         """
         :param model: A torch neural network model instance.
@@ -197,22 +37,45 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
         super().__init__()

         # check consistency of the inputs
-        check_consistency(model, torch.nn.Module)
         check_consistency(problem, AbstractProblem)
-        check_consistency(optimizer, Optimizer)
-        check_consistency(scheduler, Scheduler)
+        self._check_solver_consistency(problem)

-        # put everything in a list if only one input
-        if not isinstance(model, list):
-            model = [model]
-        if not isinstance(scheduler, list):
-            scheduler = [scheduler]
-        if not isinstance(optimizer, list):
-            optimizer = [optimizer]
+        # Check consistency of models argument and encapsulate in list
+        if not isinstance(models, list):
+            check_consistency(models, torch.nn.Module)
+            # put everything in a list if only one input
+            models = [models]
+        else:
+            for idx in range(len(models)):
+                # Check consistency
+                check_consistency(models[idx], torch.nn.Module)
+        len_model = len(models)

-        # number of models and optimizers
-        len_model = len(model)
-        len_optimizer = len(optimizer)
+        # If use_lt is True, add an extract operation on the input
+        if use_lt is True:
+            for idx in range(len(models)):
+                models[idx] = Network(
+                    model=models[idx],
+                    input_variables=problem.input_variables,
+                    output_variables=problem.output_variables,
+                    extra_features=extra_features,
+                )

+        # Check scheduler consistency + encapsulation
+        if not isinstance(schedulers, list):
+            check_consistency(schedulers, Scheduler)
+            schedulers = [schedulers]
+        else:
+            for scheduler in schedulers:
+                check_consistency(scheduler, Scheduler)

+        # Check optimizer consistency + encapsulation
+        if not isinstance(optimizers, list):
+            check_consistency(optimizers, Optimizer)
+            optimizers = [optimizers]
+        else:
+            for optimizer in optimizers:
+                check_consistency(optimizer, Optimizer)
+        len_optimizer = len(optimizers)

         # check length consistency optimizers
         if len_model != len_optimizer:
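The new constructor repeats one coerce-and-validate pattern three times, for models, schedulers, and optimizers. A compact restatement of that pattern as a hypothetical helper (illustrative only, not part of the commit):

    def as_checked_list(value, expected_type):
        # Wrap a bare object in a list, then verify every element's type;
        # mirrors the repeated isinstance/check_consistency blocks above.
        items = value if isinstance(value, list) else [value]
        for item in items:
            if not isinstance(item, expected_type):
                raise ValueError(
                    f"expected {expected_type.__name__}, got {type(item).__name__}"
                )
        return items

With such a helper, the three blocks would reduce to models = as_checked_list(models, torch.nn.Module) plus the analogous calls for optimizers and schedulers.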
@@ -223,10 +86,12 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
             )

         # extra features handling
-        self._pina_model = model
-        self._pina_optimizer = optimizer
-        self._pina_scheduler = scheduler
+        self._pina_models = models
+        self._pina_optimizers = optimizers
+        self._pina_schedulers = schedulers
         self._pina_problem = problem


     @abstractmethod
     def forward(self, *args, **kwargs):
@@ -244,13 +109,13 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
     def models(self):
         """
         The torch models."""
-        return self._pina_model
+        return self._pina_models

     @property
     def optimizers(self):
         """
         The torch optimizers."""
-        return self._pina_optimizer
+        return self._pina_optimizers

     @property
     def problem(self):
@@ -272,16 +137,10 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
         return super().on_train_start()

-    # @model.setter
-    # def model(self, new_model):
-    #     """
-    #     Set the torch."""
-    #     check_consistency(new_model, nn.Module, 'torch model')
-    #     self._model= new_model

-    # @problem.setter
-    # def problem(self, problem):
-    #     """
-    #     Set the problem formulation."""
-    #     check_consistency(problem, AbstractProblem, 'pina problem')
-    #     self._problem = problem
+    def _check_solver_consistency(self, problem):
+        """
+        Check that every condition defined in the problem is supported
+        by this solver.
+        """
+        for _, condition in problem.conditions.items():
+            if not set(self.accepted_condition_types).issubset(condition.condition_type):
+                raise ValueError(f'{self.__name__} does not support condition {condition.condition_type}')
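A toy run of the consistency gate, with stand-in problem and condition objects (all names here are hypothetical):

    class ToyCondition:
        def __init__(self, condition_type):
            self.condition_type = condition_type

    class ToyProblem:
        conditions = {"data": ToyCondition(["supervised"])}

    class ToySolver:
        accepted_condition_types = ["supervised"]
        __name__ = "ToySolver"

        def _check_solver_consistency(self, problem):
            # Same check as in the diff: every accepted type must appear
            # among the condition's declared types.
            for _, condition in problem.conditions.items():
                if not set(self.accepted_condition_types).issubset(condition.condition_type):
                    raise ValueError(
                        f"{self.__name__} does not support condition {condition.condition_type}"
                    )

    ToySolver()._check_solver_consistency(ToyProblem())  # passes silently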
@@ -2,9 +2,7 @@
 import torch
 from torch.nn.modules.loss import _Loss

-from ..optim import TorchOptimizer, TorchScheduler
+from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
 from .solver import SolverInterface
 from ..label_tensor import LabelTensor
 from ..utils import check_consistency
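The widened import suggests Optimizer and Scheduler are now needed for the type checks in the constructor. For orientation, a hedged sketch of how a deferred-construction optimizer wrapper of this kind could look; the class body below is an assumption, only the wrapper names come from the import:

    import torch

    class TorchOptimizer:
        # Assumed behaviour: store the torch optimizer class and its kwargs,
        # and instantiate only once the model parameters are available.
        def __init__(self, optimizer_cls, **kwargs):
            self.optimizer_cls = optimizer_cls
            self.kwargs = kwargs

        def hook(self, parameters):
            self.instance = self.optimizer_cls(parameters, **self.kwargs)
            return self.instance

    model = torch.nn.Linear(2, 1)
    adam = TorchOptimizer(torch.optim.Adam, lr=1e-3).hook(model.parameters())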
@@ -39,14 +37,17 @@ class SupervisedSolver(SolverInterface):
     we are seeking to approximate multiple (discretised) functions given
     multiple (discretised) input functions.
     """
+    accepted_condition_types = ['supervised']
+    __name__ = 'SupervisedSolver'

     def __init__(
-        self,
-        problem,
-        model,
-        loss=None,
-        optimizer=None,
-        scheduler=None,
+        self,
+        problem,
+        model,
+        loss=None,
+        optimizer=None,
+        scheduler=None,
+        extra_features=None
     ):
         """
         :param AbstractProblem problem: The formulation of the problem.
@@ -57,11 +58,8 @@ class SupervisedSolver(SolverInterface):
             features to use as augmented input.
         :param torch.optim.Optimizer optimizer: The neural network optimizer to
             use; default is :class:`torch.optim.Adam`.
-        :param dict optimizer_kwargs: Optimizer constructor keyword args.
-        :param float lr: The learning rate; default is 0.001.
         :param torch.optim.LRScheduler scheduler: Learning
             rate scheduler.
-        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
         """
         if loss is None:
             loss = torch.nn.MSELoss()
@@ -74,18 +72,19 @@ class SupervisedSolver(SolverInterface):
             torch.optim.lr_scheduler.ConstantLR)

         super().__init__(
-            model=model,
+            models=model,
             problem=problem,
-            optimizer=optimizer,
-            scheduler=scheduler,
+            optimizers=optimizer,
+            schedulers=scheduler,
             extra_features=extra_features
         )

         # check consistency
         check_consistency(loss, (LossInterface, _Loss), subclass=False)
         self._loss = loss
-        self._model = self._pina_model[0]
-        self._optimizer = self._pina_optimizer[0]
-        self._scheduler = self._pina_scheduler[0]
+        self._model = self._pina_models[0]
+        self._optimizer = self._pina_optimizers[0]
+        self._scheduler = self._pina_schedulers[0]

     def forward(self, x):
         """Forward pass implementation for the solver.
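The constructor falls back to standard defaults when training components are left unspecified. A standalone restatement of that fallback logic (a sketch, assuming only the defaults named in the visible context):

    import torch

    def resolve_defaults(loss=None, optimizer=None, scheduler=None):
        # Mirrors SupervisedSolver.__init__: MSE loss, Adam, and a constant
        # learning-rate schedule are used when nothing else is provided.
        if loss is None:
            loss = torch.nn.MSELoss()
        if optimizer is None:
            optimizer = torch.optim.Adam  # presumably wrapped in TorchOptimizer
        if scheduler is None:
            scheduler = torch.optim.lr_scheduler.ConstantLR  # wrapped in TorchScheduler
        return loss, optimizer, scheduler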
@@ -97,12 +96,7 @@ class SupervisedSolver(SolverInterface):
         output = self._model(x)

-        output.labels = {
-            1: {
-                "name": "output",
-                "dof": self.problem.output_variables
-            }
-        }
+        output.labels = self.problem.output_variables
         return output

     def configure_optimizers(self):
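The forward pass now attaches the problem's output variable names directly as labels, replacing the nested dict. Assuming LabelTensor's usual constructor and extract method, labelling works roughly like this:

    import torch
    from pina import LabelTensor  # import path assumed

    # A 10-sample batch with two named output components; labels let
    # downstream code select columns by name rather than by index.
    out = LabelTensor(torch.rand(10, 2), ["u", "v"])
    u = out.extract(["u"])  # tensor of shape (10, 1) holding the 'u' column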
@@ -128,16 +122,14 @@ class SupervisedSolver(SolverInterface):
         :return: The sum of the loss functions.
         :rtype: LabelTensor
         """

-        condition_idx = batch.condition
+        condition_idx = batch.supervised.condition_indices

         for condition_id in range(condition_idx.min(), condition_idx.max() + 1):

             condition_name = self._dataloader.condition_names[condition_id]
             condition = self.problem.conditions[condition_name]
-            pts = batch.input
-            out = batch.output
+            pts = batch.supervised.input_points
+            out = batch.supervised.output_points
             if condition_name not in self.problem.conditions:
                 raise RuntimeError("Something wrong happened.")
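The batch now groups supervised data under batch.supervised, with condition_indices mapping each row of the stacked points to its originating condition. A self-contained illustration of that access pattern (the plain tensors below stand in for the real batch object):

    import torch

    input_points = torch.rand(6, 2)                       # stacked inputs from all conditions
    output_points = torch.rand(6, 1)                      # stacked targets
    condition_indices = torch.tensor([0, 0, 0, 1, 1, 1])  # row -> condition id

    for cid in range(int(condition_indices.min()), int(condition_indices.max()) + 1):
        mask = condition_indices == cid
        pts, out = input_points[mask], output_points[mask]
        # ...compute this condition's loss on (pts, out) and accumulate...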
@@ -167,8 +159,8 @@ class SupervisedSolver(SolverInterface):
         the network output against the true solution. This function
         should not be overridden unless intentional.

-        :param LabelTensor input_tensor: The input to the neural networks.
-        :param LabelTensor output_tensor: The true solution to compare the
+        :param LabelTensor input_pts: The input to the neural networks.
+        :param LabelTensor output_pts: The true solution to compare the
             network solution.
         :return: The residual loss averaged on the input coordinates
         :rtype: torch.Tensor
@@ -181,7 +173,7 @@ class SupervisedSolver(SolverInterface):
         Scheduler for training.
         """
         return self._scheduler

     @property
     def optimizer(self):
         """