refact
This commit is contained in:
@@ -20,3 +20,6 @@ from .dataset import SamplePointDataset
|
|||||||
from .dataset import SamplePointLoader
|
from .dataset import SamplePointLoader
|
||||||
from .optimizer import TorchOptimizer
|
from .optimizer import TorchOptimizer
|
||||||
from .scheduler import TorchScheduler
|
from .scheduler import TorchScheduler
|
||||||
|
from .condition.condition import Condition
|
||||||
|
from .data.dataset import SamplePointDataset
|
||||||
|
from .data.dataset import SamplePointLoader
|
||||||
|
|||||||
10
pina/condition/__init__.py
Normal file
10
pina/condition/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
__all__ = [
|
||||||
|
'Condition',
|
||||||
|
'ConditionInterface',
|
||||||
|
'InputOutputCondition',
|
||||||
|
'InputEquationCondition'
|
||||||
|
'LocationEquationCondition',
|
||||||
|
]
|
||||||
|
|
||||||
|
from .condition_interface import ConditionInterface
|
||||||
|
from .input_output_condition import InputOutputCondition
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
""" Condition module. """
|
""" Condition module. """
|
||||||
|
|
||||||
from .label_tensor import LabelTensor
|
from ..label_tensor import LabelTensor
|
||||||
from .geometry import Location
|
from ..geometry import Location
|
||||||
from .equation.equation import Equation
|
from ..equation.equation import Equation
|
||||||
|
|
||||||
|
|
||||||
def dummy(a):
|
def dummy(a):
|
||||||
@@ -59,23 +59,31 @@ class Condition:
|
|||||||
"data_weight",
|
"data_weight",
|
||||||
]
|
]
|
||||||
|
|
||||||
def _dictvalue_isinstance(self, dict_, key_, class_):
|
# def _dictvalue_isinstance(self, dict_, key_, class_):
|
||||||
"""Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
|
# """Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
|
||||||
if key_ not in dict_.keys():
|
# if key_ not in dict_.keys():
|
||||||
return True
|
# return True
|
||||||
|
|
||||||
return isinstance(dict_[key_], class_)
|
# return isinstance(dict_[key_], class_)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
# def __init__(self, *args, **kwargs):
|
||||||
"""
|
# """
|
||||||
Constructor for the `Condition` class.
|
# Constructor for the `Condition` class.
|
||||||
"""
|
# """
|
||||||
self.data_weight = kwargs.pop("data_weight", 1.0)
|
# self.data_weight = kwargs.pop("data_weight", 1.0)
|
||||||
|
|
||||||
if len(args) != 0:
|
# if len(args) != 0:
|
||||||
raise ValueError(
|
# raise ValueError(
|
||||||
f"Condition takes only the following keyword arguments: {Condition.__slots__}."
|
# f"Condition takes only the following keyword arguments: {Condition.__slots__}."
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
from . import InputOutputCondition
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
|
||||||
|
if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]):
|
||||||
|
return InputOutputCondition(**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
|
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
|
||||||
15
pina/condition/condition_interface.py
Normal file
15
pina/condition/condition_interface.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionInterface(metaclass=ABCMeta):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def residual(self, model):
|
||||||
|
"""
|
||||||
|
Compute the residual of the condition.
|
||||||
|
|
||||||
|
:param model: The model to evaluate the condition.
|
||||||
|
:return: The residual of the condition.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
28
pina/condition/domain_equation_condition.py
Normal file
28
pina/condition/domain_equation_condition.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from .condition_interface import ConditionInterface
|
||||||
|
|
||||||
|
class DomainEquationCondition(ConditionInterface):
|
||||||
|
"""
|
||||||
|
Condition for input/output data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ["domain", "equation"]
|
||||||
|
|
||||||
|
def __init__(self, domain, equation):
|
||||||
|
"""
|
||||||
|
Constructor for the `InputOutputCondition` class.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.domain = domain
|
||||||
|
self.equation = equation
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def batch_residual(model, input_pts, equation):
|
||||||
|
"""
|
||||||
|
Compute the residual of the condition for a single batch. Input and
|
||||||
|
output points are provided as arguments.
|
||||||
|
|
||||||
|
:param torch.nn.Module model: The model to evaluate the condition.
|
||||||
|
:param torch.Tensor input_points: The input points.
|
||||||
|
:param torch.Tensor output_points: The output points.
|
||||||
|
"""
|
||||||
|
return equation.residual(model(input_pts))
|
||||||
23
pina/condition/input_equation_condition.py
Normal file
23
pina/condition/input_equation_condition.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
|
||||||
|
from . import ConditionInterface
|
||||||
|
|
||||||
|
class InputOutputCondition(ConditionInterface):
|
||||||
|
"""
|
||||||
|
Condition for input/output data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ["input_points", "output_points"]
|
||||||
|
|
||||||
|
def __init__(self, input_points, output_points):
|
||||||
|
"""
|
||||||
|
Constructor for the `InputOutputCondition` class.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.input_points = input_points
|
||||||
|
self.output_points = output_points
|
||||||
|
|
||||||
|
def residual(self, model):
|
||||||
|
"""
|
||||||
|
Compute the residual of the condition.
|
||||||
|
"""
|
||||||
|
return self.output_points - model(self.input_points)
|
||||||
35
pina/condition/input_output_condition.py
Normal file
35
pina/condition/input_output_condition.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
|
||||||
|
from . import ConditionInterface
|
||||||
|
|
||||||
|
class InputOutputCondition(ConditionInterface):
|
||||||
|
"""
|
||||||
|
Condition for input/output data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ["input_points", "output_points"]
|
||||||
|
|
||||||
|
def __init__(self, input_points, output_points):
|
||||||
|
"""
|
||||||
|
Constructor for the `InputOutputCondition` class.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.input_points = input_points
|
||||||
|
self.output_points = output_points
|
||||||
|
|
||||||
|
def residual(self, model):
|
||||||
|
"""
|
||||||
|
Compute the residual of the condition.
|
||||||
|
"""
|
||||||
|
return self.batch_residual(model, self.input_points, self.output_points)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def batch_residual(model, input_points, output_points):
|
||||||
|
"""
|
||||||
|
Compute the residual of the condition for a single batch. Input and
|
||||||
|
output points are provided as arguments.
|
||||||
|
|
||||||
|
:param torch.nn.Module model: The model to evaluate the condition.
|
||||||
|
:param torch.Tensor input_points: The input points.
|
||||||
|
:param torch.Tensor output_points: The output points.
|
||||||
|
"""
|
||||||
|
return output_points - model(input_points)
|
||||||
0
pina/data/__init__.py
Normal file
0
pina/data/__init__.py
Normal file
@@ -1,6 +1,6 @@
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
import torch
|
import torch
|
||||||
from .label_tensor import LabelTensor
|
from ..label_tensor import LabelTensor
|
||||||
|
|
||||||
|
|
||||||
class SamplePointDataset(Dataset):
|
class SamplePointDataset(Dataset):
|
||||||
11
pina/optim/__init__.py
Normal file
11
pina/optim/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
__all__ = [
|
||||||
|
"Optimizer",
|
||||||
|
"TorchOptimizer",
|
||||||
|
"Scheduler",
|
||||||
|
"TorchScheduler",
|
||||||
|
]
|
||||||
|
|
||||||
|
from .optimizer_interface import Optimizer
|
||||||
|
from .torch_optimizer import TorchOptimizer
|
||||||
|
from .scheduler_interface import Scheduler
|
||||||
|
from .torch_scheduler import TorchScheduler
|
||||||
7
pina/optim/optimizer_interface.py
Normal file
7
pina/optim/optimizer_interface.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
""" Module for PINA Optimizer """
|
||||||
|
|
||||||
|
from abc import ABCMeta
|
||||||
|
|
||||||
|
|
||||||
|
class Optimizer(metaclass=ABCMeta): # TODO improve interface
|
||||||
|
pass
|
||||||
7
pina/optim/scheduler_interface.py
Normal file
7
pina/optim/scheduler_interface.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
""" Module for PINA Optimizer """
|
||||||
|
|
||||||
|
from abc import ABCMeta
|
||||||
|
|
||||||
|
|
||||||
|
class Scheduler(metaclass=ABCMeta): # TODO improve interface
|
||||||
|
pass
|
||||||
19
pina/optim/torch_optimizer.py
Normal file
19
pina/optim/torch_optimizer.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
""" Module for PINA Torch Optimizer """
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..utils import check_consistency
|
||||||
|
from .optimizer_interface import Optimizer
|
||||||
|
|
||||||
|
class TorchOptimizer(Optimizer):
|
||||||
|
|
||||||
|
def __init__(self, optimizer_class, **kwargs):
|
||||||
|
check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)
|
||||||
|
|
||||||
|
self.optimizer_class = optimizer_class
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def hook(self, parameters):
|
||||||
|
self.optimizer_instance = self.optimizer_class(
|
||||||
|
parameters, **self.kwargs
|
||||||
|
)
|
||||||
27
pina/optim/torch_scheduler.py
Normal file
27
pina/optim/torch_scheduler.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
""" Module for PINA Torch Optimizer """
|
||||||
|
|
||||||
|
import torch
|
||||||
|
try:
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||||
|
except ImportError:
|
||||||
|
from torch.optim.lr_scheduler import (
|
||||||
|
_LRScheduler as LRScheduler,
|
||||||
|
) # torch < 2.0
|
||||||
|
|
||||||
|
from ..utils import check_consistency
|
||||||
|
from .optimizer_interface import Optimizer
|
||||||
|
from .scheduler_interface import Scheduler
|
||||||
|
|
||||||
|
class TorchScheduler(Scheduler):
|
||||||
|
|
||||||
|
def __init__(self, scheduler_class, **kwargs):
|
||||||
|
check_consistency(scheduler_class, LRScheduler, subclass=True)
|
||||||
|
|
||||||
|
self.scheduler_class = scheduler_class
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def hook(self, optimizer):
|
||||||
|
check_consistency(optimizer, Optimizer)
|
||||||
|
self.scheduler_instance = self.scheduler_class(
|
||||||
|
optimizer.optimizer_instance, **self.kwargs
|
||||||
|
)
|
||||||
@@ -5,10 +5,173 @@ from ..model.network import Network
|
|||||||
import pytorch_lightning
|
import pytorch_lightning
|
||||||
from ..utils import check_consistency
|
from ..utils import check_consistency
|
||||||
from ..problem import AbstractProblem
|
from ..problem import AbstractProblem
|
||||||
|
from ..optim import Optimizer, Scheduler
|
||||||
import torch
|
import torch
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
# class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||||
|
# """
|
||||||
|
# Solver base class. This class inherits is a wrapper of
|
||||||
|
# LightningModule class, inheriting all the
|
||||||
|
# LightningModule methods.
|
||||||
|
# """
|
||||||
|
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# models,
|
||||||
|
# problem,
|
||||||
|
# optimizers,
|
||||||
|
# optimizers_kwargs,
|
||||||
|
# extra_features=None,
|
||||||
|
# ):
|
||||||
|
# """
|
||||||
|
# :param models: A torch neural network model instance.
|
||||||
|
# :type models: torch.nn.Module
|
||||||
|
# :param problem: A problem definition instance.
|
||||||
|
# :type problem: AbstractProblem
|
||||||
|
# :param list(torch.optim.Optimizer) optimizer: A list of neural network optimizers to
|
||||||
|
# use.
|
||||||
|
# :param list(dict) optimizer_kwargs: A list of optimizer constructor keyword args.
|
||||||
|
# :param list(torch.nn.Module) extra_features: The additional input
|
||||||
|
# features to use as augmented input. If ``None`` no extra features
|
||||||
|
# are passed. If it is a list of :class:`torch.nn.Module`, the extra feature
|
||||||
|
# list is passed to all models. If it is a list of extra features' lists,
|
||||||
|
# each single list of extra feature is passed to a model.
|
||||||
|
# """
|
||||||
|
# super().__init__()
|
||||||
|
|
||||||
|
# # check consistency of the inputs
|
||||||
|
# check_consistency(models, torch.nn.Module)
|
||||||
|
# check_consistency(problem, AbstractProblem)
|
||||||
|
# check_consistency(optimizers, torch.optim.Optimizer, subclass=True)
|
||||||
|
# check_consistency(optimizers_kwargs, dict)
|
||||||
|
|
||||||
|
# # put everything in a list if only one input
|
||||||
|
# if not isinstance(models, list):
|
||||||
|
# models = [models]
|
||||||
|
# if not isinstance(optimizers, list):
|
||||||
|
# optimizers = [optimizers]
|
||||||
|
# optimizers_kwargs = [optimizers_kwargs]
|
||||||
|
|
||||||
|
# # number of models and optimizers
|
||||||
|
# len_model = len(models)
|
||||||
|
# len_optimizer = len(optimizers)
|
||||||
|
# len_optimizer_kwargs = len(optimizers_kwargs)
|
||||||
|
|
||||||
|
# # check length consistency optimizers
|
||||||
|
# if len_model != len_optimizer:
|
||||||
|
# raise ValueError(
|
||||||
|
# "You must define one optimizer for each model."
|
||||||
|
# f"Got {len_model} models, and {len_optimizer}"
|
||||||
|
# " optimizers."
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # check length consistency optimizers kwargs
|
||||||
|
# if len_optimizer_kwargs != len_optimizer:
|
||||||
|
# raise ValueError(
|
||||||
|
# "You must define one dictionary of keyword"
|
||||||
|
# " arguments for each optimizers."
|
||||||
|
# f"Got {len_optimizer} optimizers, and"
|
||||||
|
# f" {len_optimizer_kwargs} dicitionaries"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # extra features handling
|
||||||
|
# if (extra_features is None) or (len(extra_features) == 0):
|
||||||
|
# extra_features = [None] * len_model
|
||||||
|
# else:
|
||||||
|
# # if we only have a list of extra features
|
||||||
|
# if not isinstance(extra_features[0], (tuple, list)):
|
||||||
|
# extra_features = [extra_features] * len_model
|
||||||
|
# else: # if we have a list of list extra features
|
||||||
|
# if len(extra_features) != len_model:
|
||||||
|
# raise ValueError(
|
||||||
|
# "You passed a list of extrafeatures list with len"
|
||||||
|
# f"different of models len. Expected {len_model} "
|
||||||
|
# f"got {len(extra_features)}. If you want to use "
|
||||||
|
# "the same list of extra features for all models, "
|
||||||
|
# "just pass a list of extrafeatures and not a list "
|
||||||
|
# "of list of extra features."
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # assigning model and optimizers
|
||||||
|
# self._pina_models = []
|
||||||
|
# self._pina_optimizers = []
|
||||||
|
|
||||||
|
# for idx in range(len_model):
|
||||||
|
# model_ = Network(
|
||||||
|
# model=models[idx],
|
||||||
|
# input_variables=problem.input_variables,
|
||||||
|
# output_variables=problem.output_variables,
|
||||||
|
# extra_features=extra_features[idx],
|
||||||
|
# )
|
||||||
|
# optim_ = optimizers[idx](
|
||||||
|
# model_.parameters(), **optimizers_kwargs[idx]
|
||||||
|
# )
|
||||||
|
# self._pina_models.append(model_)
|
||||||
|
# self._pina_optimizers.append(optim_)
|
||||||
|
|
||||||
|
# # assigning problem
|
||||||
|
# self._pina_problem = problem
|
||||||
|
|
||||||
|
# @abstractmethod
|
||||||
|
# def forward(self, *args, **kwargs):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# @abstractmethod
|
||||||
|
# def training_step(self):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# @abstractmethod
|
||||||
|
# def configure_optimizers(self):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def models(self):
|
||||||
|
# """
|
||||||
|
# The torch model."""
|
||||||
|
# return self._pina_models
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def optimizers(self):
|
||||||
|
# """
|
||||||
|
# The torch model."""
|
||||||
|
# return self._pina_optimizers
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def problem(self):
|
||||||
|
# """
|
||||||
|
# The problem formulation."""
|
||||||
|
# return self._pina_problem
|
||||||
|
|
||||||
|
# def on_train_start(self):
|
||||||
|
# """
|
||||||
|
# On training epoch start this function is call to do global checks for
|
||||||
|
# the different solvers.
|
||||||
|
# """
|
||||||
|
|
||||||
|
# # 1. Check the verison for dataloader
|
||||||
|
# dataloader = self.trainer.train_dataloader
|
||||||
|
# if sys.version_info < (3, 8):
|
||||||
|
# dataloader = dataloader.loaders
|
||||||
|
# self._dataloader = dataloader
|
||||||
|
|
||||||
|
# return super().on_train_start()
|
||||||
|
|
||||||
|
# @model.setter
|
||||||
|
# def model(self, new_model):
|
||||||
|
# """
|
||||||
|
# Set the torch."""
|
||||||
|
# check_consistency(new_model, nn.Module, 'torch model')
|
||||||
|
# self._model= new_model
|
||||||
|
|
||||||
|
# @problem.setter
|
||||||
|
# def problem(self, problem):
|
||||||
|
# """
|
||||||
|
# Set the problem formulation."""
|
||||||
|
# check_consistency(problem, AbstractProblem, 'pina problem')
|
||||||
|
# self._problem = problem
|
||||||
|
|
||||||
class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
Solver base class. This class inherits is a wrapper of
|
Solver base class. This class inherits is a wrapper of
|
||||||
@@ -18,45 +181,36 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
models,
|
model,
|
||||||
problem,
|
problem,
|
||||||
optimizers,
|
optimizer,
|
||||||
optimizers_kwargs,
|
scheduler,
|
||||||
extra_features=None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param models: A torch neural network model instance.
|
:param model: A torch neural network model instance.
|
||||||
:type models: torch.nn.Module
|
:type model: torch.nn.Module
|
||||||
:param problem: A problem definition instance.
|
:param problem: A problem definition instance.
|
||||||
:type problem: AbstractProblem
|
:type problem: AbstractProblem
|
||||||
:param list(torch.optim.Optimizer) optimizer: A list of neural network optimizers to
|
:param list(torch.optim.Optimizer) optimizer: A list of neural network
|
||||||
use.
|
optimizers to use.
|
||||||
:param list(dict) optimizer_kwargs: A list of optimizer constructor keyword args.
|
|
||||||
:param list(torch.nn.Module) extra_features: The additional input
|
|
||||||
features to use as augmented input. If ``None`` no extra features
|
|
||||||
are passed. If it is a list of :class:`torch.nn.Module`, the extra feature
|
|
||||||
list is passed to all models. If it is a list of extra features' lists,
|
|
||||||
each single list of extra feature is passed to a model.
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# check consistency of the inputs
|
# check consistency of the inputs
|
||||||
check_consistency(models, torch.nn.Module)
|
check_consistency(model, torch.nn.Module)
|
||||||
check_consistency(problem, AbstractProblem)
|
check_consistency(problem, AbstractProblem)
|
||||||
check_consistency(optimizers, torch.optim.Optimizer, subclass=True)
|
check_consistency(optimizer, Optimizer)
|
||||||
check_consistency(optimizers_kwargs, dict)
|
check_consistency(scheduler, Scheduler)
|
||||||
|
|
||||||
# put everything in a list if only one input
|
# put everything in a list if only one input
|
||||||
if not isinstance(models, list):
|
if not isinstance(model, list):
|
||||||
models = [models]
|
model = [model]
|
||||||
if not isinstance(optimizers, list):
|
if not isinstance(optimizer, list):
|
||||||
optimizers = [optimizers]
|
optimizer = [optimizer]
|
||||||
optimizers_kwargs = [optimizers_kwargs]
|
|
||||||
|
|
||||||
# number of models and optimizers
|
# number of models and optimizers
|
||||||
len_model = len(models)
|
len_model = len(model)
|
||||||
len_optimizer = len(optimizers)
|
len_optimizer = len(optimizer)
|
||||||
len_optimizer_kwargs = len(optimizers_kwargs)
|
|
||||||
|
|
||||||
# check length consistency optimizers
|
# check length consistency optimizers
|
||||||
if len_model != len_optimizer:
|
if len_model != len_optimizer:
|
||||||
@@ -66,52 +220,11 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
|||||||
" optimizers."
|
" optimizers."
|
||||||
)
|
)
|
||||||
|
|
||||||
# check length consistency optimizers kwargs
|
|
||||||
if len_optimizer_kwargs != len_optimizer:
|
|
||||||
raise ValueError(
|
|
||||||
"You must define one dictionary of keyword"
|
|
||||||
" arguments for each optimizers."
|
|
||||||
f"Got {len_optimizer} optimizers, and"
|
|
||||||
f" {len_optimizer_kwargs} dicitionaries"
|
|
||||||
)
|
|
||||||
|
|
||||||
# extra features handling
|
# extra features handling
|
||||||
if (extra_features is None) or (len(extra_features) == 0):
|
|
||||||
extra_features = [None] * len_model
|
|
||||||
else:
|
|
||||||
# if we only have a list of extra features
|
|
||||||
if not isinstance(extra_features[0], (tuple, list)):
|
|
||||||
extra_features = [extra_features] * len_model
|
|
||||||
else: # if we have a list of list extra features
|
|
||||||
if len(extra_features) != len_model:
|
|
||||||
raise ValueError(
|
|
||||||
"You passed a list of extrafeatures list with len"
|
|
||||||
f"different of models len. Expected {len_model} "
|
|
||||||
f"got {len(extra_features)}. If you want to use "
|
|
||||||
"the same list of extra features for all models, "
|
|
||||||
"just pass a list of extrafeatures and not a list "
|
|
||||||
"of list of extra features."
|
|
||||||
)
|
|
||||||
|
|
||||||
# assigning model and optimizers
|
|
||||||
self._pina_models = []
|
|
||||||
self._pina_optimizers = []
|
|
||||||
|
|
||||||
for idx in range(len_model):
|
|
||||||
model_ = Network(
|
|
||||||
model=models[idx],
|
|
||||||
input_variables=problem.input_variables,
|
|
||||||
output_variables=problem.output_variables,
|
|
||||||
extra_features=extra_features[idx],
|
|
||||||
)
|
|
||||||
optim_ = optimizers[idx](
|
|
||||||
model_.parameters(), **optimizers_kwargs[idx]
|
|
||||||
)
|
|
||||||
self._pina_models.append(model_)
|
|
||||||
self._pina_optimizers.append(optim_)
|
|
||||||
|
|
||||||
# assigning problem
|
|
||||||
self._pina_problem = problem
|
self._pina_problem = problem
|
||||||
|
self._pina_model = model
|
||||||
|
self._pina_optimizer = optimizer
|
||||||
|
self._pina_scheduler = scheduler
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
@@ -129,13 +242,13 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
|
|||||||
def models(self):
|
def models(self):
|
||||||
"""
|
"""
|
||||||
The torch model."""
|
The torch model."""
|
||||||
return self._pina_models
|
return self._pina_model
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def optimizers(self):
|
def optimizers(self):
|
||||||
"""
|
"""
|
||||||
The torch model."""
|
The torch model."""
|
||||||
return self._pina_optimizers
|
return self._pina_optimizer
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def problem(self):
|
def problem(self):
|
||||||
|
|||||||
@@ -1,21 +1,14 @@
|
|||||||
""" Module for SupervisedSolver """
|
""" Module for SupervisedSolver """
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.modules.loss import _Loss
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
|
||||||
except ImportError:
|
|
||||||
from torch.optim.lr_scheduler import (
|
|
||||||
_LRScheduler as LRScheduler,
|
|
||||||
) # torch < 2.0
|
|
||||||
|
|
||||||
from torch.optim.lr_scheduler import ConstantLR
|
|
||||||
|
|
||||||
|
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
|
||||||
from .solver import SolverInterface
|
from .solver import SolverInterface
|
||||||
from ..label_tensor import LabelTensor
|
from ..label_tensor import LabelTensor
|
||||||
from ..utils import check_consistency
|
from ..utils import check_consistency
|
||||||
from ..loss import LossInterface
|
from ..loss import LossInterface
|
||||||
from torch.nn.modules.loss import _Loss
|
|
||||||
|
|
||||||
|
|
||||||
class SupervisedSolver(SolverInterface):
|
class SupervisedSolver(SolverInterface):
|
||||||
@@ -51,12 +44,9 @@ class SupervisedSolver(SolverInterface):
|
|||||||
self,
|
self,
|
||||||
problem,
|
problem,
|
||||||
model,
|
model,
|
||||||
extra_features=None,
|
loss=None,
|
||||||
loss=torch.nn.MSELoss(),
|
optimizer=None,
|
||||||
optimizer=torch.optim.Adam,
|
scheduler=None,
|
||||||
optimizer_kwargs={"lr": 0.001},
|
|
||||||
scheduler=ConstantLR,
|
|
||||||
scheduler_kwargs={"factor": 1, "total_iters": 0},
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param AbstractProblem problem: The formualation of the problem.
|
:param AbstractProblem problem: The formualation of the problem.
|
||||||
@@ -73,24 +63,26 @@ class SupervisedSolver(SolverInterface):
|
|||||||
rate scheduler.
|
rate scheduler.
|
||||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||||
"""
|
"""
|
||||||
|
if loss is None:
|
||||||
|
loss = torch.nn.MSELoss()
|
||||||
|
|
||||||
|
if optimizer is None:
|
||||||
|
optimizer = TorchOptimizer(torch.optim.Adam, lr=0.001)
|
||||||
|
|
||||||
|
if scheduler is None:
|
||||||
|
scheduler = TorchScheduler(
|
||||||
|
torch.optim.lr_scheduler.ConstantLR)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
models=[model],
|
model=model,
|
||||||
problem=problem,
|
problem=problem,
|
||||||
optimizers=[optimizer],
|
optimizer=optimizer,
|
||||||
optimizers_kwargs=[optimizer_kwargs],
|
scheduler=scheduler,
|
||||||
extra_features=extra_features,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# check consistency
|
# check consistency
|
||||||
check_consistency(scheduler, LRScheduler, subclass=True)
|
|
||||||
check_consistency(scheduler_kwargs, dict)
|
|
||||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||||
|
|
||||||
# assign variables
|
|
||||||
self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
|
|
||||||
self._loss = loss
|
|
||||||
self._neural_net = self.models[0]
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""Forward pass implementation for the solver.
|
"""Forward pass implementation for the solver.
|
||||||
|
|
||||||
@@ -98,7 +90,7 @@ class SupervisedSolver(SolverInterface):
|
|||||||
:return: Solver solution.
|
:return: Solver solution.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
return self.neural_net(x)
|
return self._pina_model(x)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
"""Optimizer configuration for the solver.
|
"""Optimizer configuration for the solver.
|
||||||
@@ -106,7 +98,9 @@ class SupervisedSolver(SolverInterface):
|
|||||||
:return: The optimizers and the schedulers
|
:return: The optimizers and the schedulers
|
||||||
:rtype: tuple(list, list)
|
:rtype: tuple(list, list)
|
||||||
"""
|
"""
|
||||||
return self.optimizers, [self.scheduler]
|
self._pina_optimizer.hook(self._pina_model.parameters())
|
||||||
|
self._pina_scheduler.hook(self._pina_optimizer)
|
||||||
|
return self._pina_optimizer, self._pina_scheduler
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
"""Solver training step.
|
"""Solver training step.
|
||||||
@@ -168,14 +162,21 @@ class SupervisedSolver(SolverInterface):
|
|||||||
"""
|
"""
|
||||||
Scheduler for training.
|
Scheduler for training.
|
||||||
"""
|
"""
|
||||||
return self._scheduler
|
return self._pina_scheduler
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def neural_net(self):
|
def optimizer(self):
|
||||||
|
"""
|
||||||
|
Optimizer for training.
|
||||||
|
"""
|
||||||
|
return self._pina_optimizer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self):
|
||||||
"""
|
"""
|
||||||
Neural network for training.
|
Neural network for training.
|
||||||
"""
|
"""
|
||||||
return self._neural_net
|
return self._pina_model
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loss(self):
|
def loss(self):
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import pytorch_lightning
|
import pytorch_lightning
|
||||||
from .utils import check_consistency
|
from .utils import check_consistency
|
||||||
from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
|
from .data.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
|
||||||
from .solvers.solver import SolverInterface
|
from .solvers.solver import SolverInterface
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from pina.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
|
from pina.data.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset
|
||||||
from pina import LabelTensor, Condition
|
from pina import LabelTensor, Condition
|
||||||
from pina.equation import Equation
|
from pina.equation import Equation
|
||||||
from pina.geometry import CartesianDomain
|
from pina.geometry import CartesianDomain
|
||||||
|
|||||||
@@ -11,8 +11,11 @@ from pina.loss import LpLoss
|
|||||||
class NeuralOperatorProblem(AbstractProblem):
|
class NeuralOperatorProblem(AbstractProblem):
|
||||||
input_variables = ['u_0', 'u_1']
|
input_variables = ['u_0', 'u_1']
|
||||||
output_variables = ['u']
|
output_variables = ['u']
|
||||||
conditions = {'data' : Condition(input_points=LabelTensor(torch.rand(100, 2), input_variables),
|
conditions = {
|
||||||
output_points=LabelTensor(torch.rand(100, 1), output_variables))}
|
# 'data' : Condition(
|
||||||
|
# input_points=LabelTensor(torch.rand(100, 2), input_variables),
|
||||||
|
# output_points=LabelTensor(torch.rand(100, 1), output_variables))
|
||||||
|
}
|
||||||
|
|
||||||
class myFeature(torch.nn.Module):
|
class myFeature(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -39,63 +42,63 @@ model_extra_feats = FeedForward(
|
|||||||
|
|
||||||
|
|
||||||
def test_constructor():
|
def test_constructor():
|
||||||
SupervisedSolver(problem=problem, model=model, extra_features=None)
|
SupervisedSolver(problem=problem, model=model)
|
||||||
|
|
||||||
|
|
||||||
def test_constructor_extra_feats():
|
# def test_constructor_extra_feats():
|
||||||
SupervisedSolver(problem=problem, model=model_extra_feats, extra_features=extra_feats)
|
# SupervisedSolver(problem=problem, model=model_extra_feats, extra_features=extra_feats)
|
||||||
|
|
||||||
|
|
||||||
def test_train_cpu():
|
def test_train_cpu():
|
||||||
solver = SupervisedSolver(problem = problem, model=model, extra_features=None, loss=LpLoss())
|
solver = SupervisedSolver(problem = problem, model=model, loss=LpLoss())
|
||||||
trainer = Trainer(solver=solver, max_epochs=3, accelerator='cpu', batch_size=20)
|
trainer = Trainer(solver=solver, max_epochs=3, accelerator='cpu', batch_size=20)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
def test_train_restore():
|
# def test_train_restore():
|
||||||
tmpdir = "tests/tmp_restore"
|
# tmpdir = "tests/tmp_restore"
|
||||||
solver = SupervisedSolver(problem=problem,
|
# solver = SupervisedSolver(problem=problem,
|
||||||
model=model,
|
# model=model,
|
||||||
extra_features=None,
|
# extra_features=None,
|
||||||
loss=LpLoss())
|
# loss=LpLoss())
|
||||||
trainer = Trainer(solver=solver,
|
# trainer = Trainer(solver=solver,
|
||||||
max_epochs=5,
|
# max_epochs=5,
|
||||||
accelerator='cpu',
|
# accelerator='cpu',
|
||||||
default_root_dir=tmpdir)
|
# default_root_dir=tmpdir)
|
||||||
trainer.train()
|
# trainer.train()
|
||||||
ntrainer = Trainer(solver=solver, max_epochs=15, accelerator='cpu')
|
# ntrainer = Trainer(solver=solver, max_epochs=15, accelerator='cpu')
|
||||||
t = ntrainer.train(
|
# t = ntrainer.train(
|
||||||
ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
|
# ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
|
||||||
import shutil
|
# import shutil
|
||||||
shutil.rmtree(tmpdir)
|
# shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
def test_train_load():
|
# def test_train_load():
|
||||||
tmpdir = "tests/tmp_load"
|
# tmpdir = "tests/tmp_load"
|
||||||
solver = SupervisedSolver(problem=problem,
|
# solver = SupervisedSolver(problem=problem,
|
||||||
model=model,
|
# model=model,
|
||||||
extra_features=None,
|
# extra_features=None,
|
||||||
loss=LpLoss())
|
# loss=LpLoss())
|
||||||
trainer = Trainer(solver=solver,
|
# trainer = Trainer(solver=solver,
|
||||||
max_epochs=15,
|
# max_epochs=15,
|
||||||
accelerator='cpu',
|
# accelerator='cpu',
|
||||||
default_root_dir=tmpdir)
|
# default_root_dir=tmpdir)
|
||||||
trainer.train()
|
# trainer.train()
|
||||||
new_solver = SupervisedSolver.load_from_checkpoint(
|
# new_solver = SupervisedSolver.load_from_checkpoint(
|
||||||
f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
|
# f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
|
||||||
problem = problem, model=model)
|
# problem = problem, model=model)
|
||||||
test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables)
|
# test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables)
|
||||||
assert new_solver.forward(test_pts).shape == (20, 1)
|
# assert new_solver.forward(test_pts).shape == (20, 1)
|
||||||
assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
|
# assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
|
||||||
torch.testing.assert_close(
|
# torch.testing.assert_close(
|
||||||
new_solver.forward(test_pts),
|
# new_solver.forward(test_pts),
|
||||||
solver.forward(test_pts))
|
# solver.forward(test_pts))
|
||||||
import shutil
|
# import shutil
|
||||||
shutil.rmtree(tmpdir)
|
# shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
def test_train_extra_feats_cpu():
|
# def test_train_extra_feats_cpu():
|
||||||
pinn = SupervisedSolver(problem=problem,
|
# pinn = SupervisedSolver(problem=problem,
|
||||||
model=model_extra_feats,
|
# model=model_extra_feats,
|
||||||
extra_features=extra_feats)
|
# extra_features=extra_feats)
|
||||||
trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
|
# trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
|
||||||
trainer.train()
|
# trainer.train()
|
||||||
Reference in New Issue
Block a user