Solvers for multiple models (#133)

* Solvers for multiple models
- Implementing support for solvers built on multiple models (e.g. GANs)
- Implementing the GAROM solver, see https://arxiv.org/abs/2305.15881 (a usage sketch is given below)
- Implementing tests for the GAROM solver (CPU only)
- Fixing the PINN docs
- Creating a solvers directory, for consistency within the package
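
A minimal usage sketch of the new multi-model solver API (the toy problem,
networks, and data below are illustrative placeholders, not part of this PR;
the Trainer call mirrors the tests added here):

    import torch
    import torch.nn as nn
    from pina import Condition, LabelTensor
    from pina.problem import AbstractProblem
    from pina.solvers import GAROM
    from pina.trainer import Trainer

    # toy data-driven setting: parameters -> snapshots
    params = LabelTensor(torch.rand(50, 2), labels=['mu1', 'mu2'])
    snapshots = LabelTensor(torch.rand(50, 10),
                            labels=[f'u_{i}' for i in range(10)])

    class ToyProblem(AbstractProblem):
        output_variables = [f'u_{i}' for i in range(10)]
        conditions = {'data': Condition(input_points=params,
                                        output_points=snapshots)}

    class ToyGenerator(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(2, 16), nn.SiLU(),
                                     nn.Linear(16, 10))
        def forward(self, param):
            return self.net(param)

    class ToyDiscriminator(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(12, 16), nn.ReLU(),
                                     nn.Linear(16, 10))
        def forward(self, data):
            x, condition = data
            return self.net(torch.cat((x, condition), dim=-1))

    # two models, one solver
    solver = GAROM(problem=ToyProblem(), generator=ToyGenerator(),
                   discriminator=ToyDiscriminator())
    trainer = Trainer(solver=solver,
                      kwargs={'max_epochs': 2, 'accelerator': 'cpu'})
    trainer.train()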


---------

Co-authored-by: Dario Coscia <dariocoscia@dhcp-040.eduroam.sissa.it>
Dario Coscia
2023-06-28 14:44:49 +02:00
committed by Nicola Demo
parent 6c8635c316
commit 701046661f
9 changed files with 612 additions and 81 deletions

pina/__init__.py

@@ -10,7 +10,7 @@ __all__ = [
from .meta import *
from .label_tensor import LabelTensor
from .pinn import PINN
from .solvers.pinn import PINN
from .trainer import Trainer
from .plotter import Plotter
from .condition import Condition

pina/solver.py (deleted)

@@ -1,65 +0,0 @@
""" Solver module. """
from abc import ABCMeta, abstractmethod
from .model.network import Network
import lightning.pytorch as pl
from .utils import check_consistency
from .problem import AbstractProblem
class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
""" Solver base class. """
def __init__(self, model, problem, extra_features=None):
"""
:param model: A torch neural network model instance.
:type model: torch.nn.Module
:param problem: A problem definition instance.
:type problem: AbstractProblem
:param list(torch.nn.Module) extra_features: the additional input
features to use as augmented input.
"""
super().__init__()
# check inheritance for pina problem
check_consistency(problem, AbstractProblem)
# assigning class variables (check consistency inside Network class)
self._pina_model = Network(model=model, extra_features=extra_features)
self._problem = problem
@abstractmethod
def forward(self):
pass
@abstractmethod
def training_step(self):
pass
@abstractmethod
def configure_optimizers(self):
pass
@property
def model(self):
"""
The torch model."""
return self._pina_model
@property
def problem(self):
"""
The problem formulation."""
return self._problem
# @model.setter
# def model(self, new_model):
# """
# Set the torch."""
# check_consistency(new_model, nn.Module, 'torch model')
# self._model= new_model
# @problem.setter
# def problem(self, problem):
# """
# Set the problem formulation."""
# check_consistency(problem, AbstractProblem, 'pina problem')
# self._problem = problem

pina/solvers/__init__.py (new file, 7 lines)

@@ -0,0 +1,7 @@
__all__ = [
'PINN',
'GAROM',
]
from .garom import GAROM
from .pinn import PINN

pina/solvers/garom.py (new file, 261 lines)

@@ -0,0 +1,261 @@
""" Module for PINN """
import torch
try:
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
except ImportError:
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # torch < 2.0
from torch.optim.lr_scheduler import ConstantLR
from .solver import SolverInterface
from ..utils import check_consistency
from ..loss import LossInterface, LpLoss
from torch.nn.modules.loss import _Loss
class GAROM(SolverInterface):
"""
    GAROM solver class. This class implements the Generative Adversarial
    Reduced Order Model solver, using user specified ``models`` to solve
    a specific reduced order ``problem``.

    .. seealso::

        **Original reference**: Coscia, D., Demo, N., & Rozza, G. (2023).
        *Generative Adversarial Reduced Order Modelling*.
        arXiv preprint `arXiv:2305.15881
        <https://doi.org/10.48550/arXiv.2305.15881>`_.
"""
def __init__(self,
problem,
generator,
discriminator,
extra_features=None,
loss = None,
optimizer_generator=torch.optim.Adam,
optimizer_generator_kwargs={'lr' : 0.001},
optimizer_discriminator=torch.optim.Adam,
optimizer_discriminator_kwargs={'lr' : 0.001},
scheduler_generator=ConstantLR,
scheduler_generator_kwargs={"factor": 1, "total_iters": 0},
scheduler_discriminator=ConstantLR,
scheduler_discriminator_kwargs={"factor": 1, "total_iters": 0},
gamma = 0.3,
lambda_k = 0.001,
regularizer = False,
):
"""
        :param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module generator: The neural network model to use
for the generator.
:param torch.nn.Module discriminator: The neural network model to use
for the discriminator.
        :param extra_features: The additional input features to use as
            augmented input. It should be either a list of ``torch.nn.Module``
            or a dictionary. If a list is passed, the same extra features are
            used by both networks. If a dictionary is passed, the keys must be
            ``generator`` and ``discriminator``, and each value must be a list
            of ``torch.nn.Module`` extra features for the corresponding network.
        :type extra_features: list(torch.nn.Module) | dict
        :param torch.nn.Module loss: The loss function to be minimized,
            default ``None``. If ``loss`` is ``None``, the default
            ``LpLoss(p=1)`` is used, as in the original paper.
        :param torch.optim.Optimizer optimizer_generator: The neural
            network optimizer to use for the generator network,
            default is ``torch.optim.Adam``.
        :param dict optimizer_generator_kwargs: Optimizer constructor keyword
            arguments for the generator.
        :param torch.optim.Optimizer optimizer_discriminator: The neural
            network optimizer to use for the discriminator network,
            default is ``torch.optim.Adam``.
        :param dict optimizer_discriminator_kwargs: Optimizer constructor keyword
            arguments for the discriminator.
:param torch.optim.LRScheduler scheduler_generator: Learning
rate scheduler for the generator.
:param dict scheduler_generator_kwargs: LR scheduler constructor keyword args.
:param torch.optim.LRScheduler scheduler_discriminator: Learning
rate scheduler for the discriminator.
:param dict scheduler_discriminator_kwargs: LR scheduler constructor keyword args.
:param gamma: Ratio of expected loss for generator and discriminator, defaults to 0.3.
:type gamma: float, optional
:param lambda_k: Learning rate for control theory optimization, defaults to 0.001.
:type lambda_k: float, optional
:param regularizer: Regularization term in the GAROM loss, defaults to False.
:type regularizer: bool, optional
        .. warning::
            The algorithm works only for data-driven problems. Hence, the
            conditions in the ``problem`` definition must contain only
            ``input_points`` (e.g. coefficient parameters, time parameters)
            and ``output_points``.
"""
if isinstance(extra_features, dict):
extra_features = [extra_features['generator'], extra_features['discriminator']]
super().__init__(models=[generator, discriminator],
problem=problem,
extra_features=extra_features,
optimizers=[optimizer_generator, optimizer_discriminator],
optimizers_kwargs=[optimizer_generator_kwargs, optimizer_discriminator_kwargs])
        # disable automatic optimization: GAN training is optimized manually
self.automatic_optimization = False
# set loss
if loss is None:
loss = LpLoss(p=1)
# check consistency
check_consistency(scheduler_generator, LRScheduler, subclass=True)
check_consistency(scheduler_generator_kwargs, dict)
check_consistency(scheduler_discriminator, LRScheduler, subclass=True)
check_consistency(scheduler_discriminator_kwargs, dict)
check_consistency(loss, (LossInterface, _Loss))
check_consistency(gamma, float)
check_consistency(lambda_k, float)
check_consistency(regularizer, bool)
# assign schedulers
self._schedulers = [scheduler_generator(self.optimizers[0],
**scheduler_generator_kwargs),
scheduler_discriminator(self.optimizers[1],
**scheduler_discriminator_kwargs)]
        # loss
self._loss = loss
        # BEGAN-style hyperparameters
self.k = 0
self.gamma = gamma
self.lambda_k = lambda_k
self.regularizer = float(regularizer)
def forward(self, x, mc_steps=20, variance=False):
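        """
        Forward step for the GAROM solver: a Monte Carlo estimate of the
        generator output, obtained by sampling the generator ``mc_steps``
        times and averaging.

        :param x: The input (parameter) tensor.
        :param int mc_steps: Number of generator samples to average, default 20.
        :param bool variance: If ``True``, also return the sample variance.
        :return: The mean field, and optionally its variance.
        """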
# sampling
field_sample = [self.sample(x) for _ in range(mc_steps)]
field_sample = torch.stack(field_sample)
# extract mean
mean = field_sample.mean(dim=0)
if variance:
var = field_sample.var(dim=0)
return mean, var
return mean
def configure_optimizers(self):
"""Optimizer configuration for the GAROM
solver.
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
return self.optimizers, self._schedulers
def sample(self, x):
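        """Draw a single generator realisation for the given input parameters."""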
# sampling
return self.generator(x)
def training_step(self, batch, batch_idx):
"""PINN solver training step.
:param batch: The batch element in the dataloader.
:type batch: tuple
:param batch_idx: The batch index.
:type batch_idx: int
:return: The sum of the loss functions.
:rtype: LabelTensor
"""
for condition_name, samples in batch.items():
            if condition_name not in self.problem.conditions:
                raise RuntimeError(
                    f'Condition {condition_name} not defined in the problem.')
condition = self.problem.conditions[condition_name]
# for data driven mode
if hasattr(condition, 'output_points'):
# get data
parameters, input_pts = samples
# get optimizers
opt_gen, opt_disc = self.optimizers
# ---------------------
# Train Discriminator
# ---------------------
opt_disc.zero_grad()
                # Generate a batch of snapshots
gen_imgs = self.generator(parameters)
# Discriminator pass
d_real = self.discriminator([input_pts, parameters])
d_fake = self.discriminator([gen_imgs.detach(), parameters])
# evaluate loss
d_loss_real = self._loss(d_real, input_pts)
d_loss_fake = self._loss(d_fake, gen_imgs.detach())
d_loss = d_loss_real - self.k * d_loss_fake
# backward step
d_loss.backward()
opt_disc.step()
# -----------------
# Train Generator
# -----------------
opt_gen.zero_grad()
                # Generate a batch of snapshots
gen_imgs = self.generator(parameters)
# generator loss
r_loss = self._loss(input_pts, gen_imgs)
d_fake = self.discriminator([gen_imgs, parameters])
g_loss = self._loss(d_fake, gen_imgs) + self.regularizer * r_loss
# backward step
g_loss.backward()
opt_gen.step()
# ----------------
# Update weights
# ----------------
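                # BEGAN-style equilibrium control: gamma sets the target ratio
                # between discriminator and generator losses, lambda_k is the
                # proportional gain used to update the balancing term k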
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
# Update weight term for fake samples
self.k += self.lambda_k * diff.item()
self.k = min(max(self.k, 0), 1) # Constraint to interval [0, 1]
else:
raise NotImplementedError('GAROM works only in data-driven mode.')
return
@property
def generator(self):
return self.models[0]
@property
def discriminator(self):
return self.models[1]
@property
def optimizer_generator(self):
return self.optimizers[0]
@property
def optimizer_discriminator(self):
return self.optimizers[1]
@property
def scheduler_generator(self):
return self._schedulers[0]
@property
def scheduler_discriminator(self):
return self._schedulers[1]

pina/solvers/pinn.py (moved from pina/pinn.py)

@@ -8,10 +8,9 @@ except ImportError:
from torch.optim.lr_scheduler import ConstantLR
from .solver import SolverInterface
from .label_tensor import LabelTensor
from .utils import check_consistency
from .writer import Writer
from .loss import LossInterface
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss import LossInterface
from torch.nn.modules.loss import _Loss
@@ -19,7 +18,18 @@ torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
class PINN(SolverInterface):
"""
PINN solver class. This class implements Physics Informed Neural
Network solvers, using a user specified ``model`` to solve a specific
``problem``.
.. seealso::
        **Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L.,
        Perdikaris, P., Wang, S., & Yang, L. (2021).
        *Physics-informed machine learning*. Nature Reviews Physics, 3(6), 422-440.
        DOI: `10.1038/s42254-021-00314-5
        <https://doi.org/10.1038/s42254-021-00314-5>`_.
"""
def __init__(self,
problem,
model,
@@ -45,20 +55,21 @@ class PINN(SolverInterface):
rate scheduler.
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
'''
super().__init__(model=model, problem=problem, extra_features=extra_features)
super().__init__(models=[model],
problem=problem,
optimizers=[optimizer],
optimizers_kwargs=[optimizer_kwargs],
extra_features=extra_features)
# check consistency
check_consistency(optimizer, torch.optim.Optimizer, subclass=True)
check_consistency(optimizer_kwargs, dict)
check_consistency(scheduler, LRScheduler, subclass=True)
check_consistency(scheduler_kwargs, dict)
check_consistency(loss, (LossInterface, _Loss), subclass=False)
# assign variables
self._optimizer = optimizer(self.model.parameters(), **optimizer_kwargs)
self._scheduler = scheduler(self._optimizer, **scheduler_kwargs)
self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
self._loss = loss
self._writer = Writer()
self._neural_net = self.models[0]
def forward(self, x):
@@ -72,7 +83,7 @@ class PINN(SolverInterface):
# extract labels
x = x.extract(self.problem.input_variables)
# perform forward pass
output = self.model(x).as_subclass(LabelTensor)
output = self.neural_net(x).as_subclass(LabelTensor)
# set the labels
output.labels = self.problem.output_variables
return output
@@ -84,7 +95,7 @@ class PINN(SolverInterface):
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
"""
return [self._optimizer], [self._scheduler]
return self.optimizers, [self.scheduler]
def training_step(self, batch, batch_idx):
"""PINN solver training step.
@@ -109,11 +120,11 @@ class PINN(SolverInterface):
# PINN loss: equation evaluated on location or input_points
if hasattr(condition, 'equation'):
target = condition.equation.residual(samples, self.forward(samples))
loss = self._loss(torch.zeros_like(target), target)
loss = self.loss(torch.zeros_like(target), target)
# PINN loss: evaluate model(input_points) vs output_points
elif hasattr(condition, 'output_points'):
input_pts, output_pts = samples
loss = self._loss(self.forward(input_pts), output_pts)
loss = self.loss(self.forward(input_pts), output_pts)
condition_losses.append(loss * condition.data_weight)
@@ -121,3 +132,24 @@ class PINN(SolverInterface):
# we need to pass it as a torch tensor to make everything work
total_loss = sum(condition_losses)
return total_loss
@property
def scheduler(self):
"""
Scheduler for the PINN training.
"""
return self._scheduler
@property
def neural_net(self):
"""
Neural network for the PINN training.
"""
return self._neural_net
@property
def loss(self):
"""
Loss for the PINN training.
"""
return self._loss

pina/solvers/solver.py (new file, 134 lines)

@@ -0,0 +1,134 @@
""" Solver module. """
from abc import ABCMeta, abstractmethod
from ..model.network import Network
import lightning.pytorch as pl
from ..utils import check_consistency
from ..problem import AbstractProblem
import torch
class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
""" Solver base class. """
def __init__(self,
models,
problem,
optimizers,
optimizers_kwargs,
extra_features=None):
"""
        :param models: A torch neural network model instance, or a list of
            model instances (one per optimizer).
        :type models: torch.nn.Module | list(torch.nn.Module)
        :param problem: A problem definition instance.
        :type problem: AbstractProblem
        :param optimizers: A torch optimizer class, or a list of optimizer
            classes (one per model).
        :type optimizers: torch.optim.Optimizer | list(torch.optim.Optimizer)
        :param optimizers_kwargs: Optimizer constructor keyword arguments, or a
            list of keyword argument dictionaries (one per optimizer).
        :type optimizers_kwargs: dict | list(dict)
        :param list(torch.nn.Module) extra_features: The additional input
            features to use as augmented input. If ``None``, no extra features
            are used. If a list of ``torch.nn.Module`` is passed, the same
            extra feature list is passed to all models. If a list of lists is
            passed, each inner list of extra features is passed to the
            corresponding model.
"""
super().__init__()
# check consistency of the inputs
check_consistency(models, torch.nn.Module)
check_consistency(problem, AbstractProblem)
check_consistency(optimizers, torch.optim.Optimizer, subclass=True)
check_consistency(optimizers_kwargs, dict)
# put everything in a list if only one input
if not isinstance(models, list):
models = [models]
if not isinstance(optimizers, list):
optimizers = [optimizers]
optimizers_kwargs = [optimizers_kwargs]
# number of models and optimizers
len_model = len(models)
len_optimizer = len(optimizers)
len_optimizer_kwargs = len(optimizers_kwargs)
# check length consistency optimizers
if len_model != len_optimizer:
            raise ValueError('You must define one optimizer for each model. '
                             f'Got {len_model} models and {len_optimizer} '
                             'optimizers.')
# check length consistency optimizers kwargs
if len_optimizer_kwargs != len_optimizer:
            raise ValueError('You must define one dictionary of keyword '
                             'arguments for each optimizer. '
                             f'Got {len_optimizer} optimizers and '
                             f'{len_optimizer_kwargs} dictionaries.')
# extra features handling
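        # accepted formats: None (no extra features), a flat list shared by
        # every model, or a list of lists with one inner list per model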
if extra_features is None:
extra_features = [None] * len_model
else:
# if we only have a list of extra features
if not isinstance(extra_features[0], (tuple, list)):
extra_features = [extra_features] * len_model
else: # if we have a list of list extra features
if len(extra_features) != len_model:
                    raise ValueError(
                        'You passed a list of extra feature lists whose length '
                        f'differs from the number of models. Expected {len_model}, '
                        f'got {len(extra_features)}. If you want to use the same '
                        'extra features for all models, pass a single list of '
                        'extra features instead of a list of lists.')
# assigning model and optimizers
self._pina_models = []
self._pina_optimizers = []
for idx in range(len_model):
model_ = Network(model=models[idx], extra_features=extra_features[idx])
optim_ = optimizers[idx](model_.parameters(), **optimizers_kwargs[idx])
self._pina_models.append(model_)
self._pina_optimizers.append(optim_)
# assigning problem
self._pina_problem = problem
@abstractmethod
def forward(self):
pass
@abstractmethod
def training_step(self):
pass
@abstractmethod
def configure_optimizers(self):
pass
@property
def models(self):
"""
        The torch models."""
return self._pina_models
@property
def optimizers(self):
"""
        The torch optimizers."""
return self._pina_optimizers
@property
def problem(self):
"""
The problem formulation."""
return self._pina_problem
# @model.setter
# def model(self, new_model):
# """
# Set the torch."""
# check_consistency(new_model, nn.Module, 'torch model')
# self._model= new_model
# @problem.setter
# def problem(self, problem):
# """
# Set the problem formulation."""
# check_consistency(problem, AbstractProblem, 'pina problem')
# self._problem = problem

pina/trainer.py

@@ -3,7 +3,7 @@
import lightning.pytorch as pl
from .utils import check_consistency
from .dataset import DummyLoader
from .solver import SolverInterface
from .solvers.solver import SolverInterface
class Trainer(pl.Trainer):

GAROM solver tests (new file, 162 lines)

@@ -0,0 +1,162 @@
import torch
from pina.problem import AbstractProblem
from pina import Condition, LabelTensor
from pina.solvers import GAROM
from pina.trainer import Trainer
import torch.nn as nn
import matplotlib.tri as tri
def func(x, mu1, mu2):
    # parametric Gaussian bump centred at (mu1, mu2)
    x_m1 = (x[:, 0] - mu1).pow(2)
    x_m2 = (x[:, 1] - mu2).pow(2)
    return torch.exp(-(x_m1 + x_m2))
class ParametricGaussian(AbstractProblem):
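    """Toy data-driven problem: Gaussian snapshots parametrised by the centre (mu1, mu2)."""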
output_variables = [f'u_{i}' for i in range(900)]
# params
xx = torch.linspace(-1, 1, 20)
yy = xx
params = LabelTensor(torch.cartesian_prod(xx, yy), labels=['mu1', 'mu2'])
# define domain
x = torch.linspace(-1, 1, 30)
domain = torch.cartesian_prod(x, x)
triang = tri.Triangulation(domain[:, 0], domain[:, 1])
sol = []
for p in params:
sol.append(func(domain, p[0], p[1]))
snapshots = LabelTensor(torch.stack(sol), labels=output_variables)
# define conditions
conditions = {
'data': Condition(
input_points=params,
output_points=snapshots)
}
# simple Generator Network
class Generator(nn.Module):
def __init__(self, input_dimension, parameters_dimension,
noise_dimension, activation=torch.nn.SiLU):
super().__init__()
self._noise_dimension = noise_dimension
self._activation = activation
self.model = torch.nn.Sequential(
torch.nn.Linear(6 * self._noise_dimension, input_dimension // 6),
self._activation(),
torch.nn.Linear(input_dimension // 6, input_dimension // 3),
self._activation(),
torch.nn.Linear(input_dimension // 3, input_dimension)
)
self.condition = torch.nn.Sequential(
torch.nn.Linear(parameters_dimension, 2 * self._noise_dimension),
self._activation(),
torch.nn.Linear(2 * self._noise_dimension, 5 * self._noise_dimension)
)
def forward(self, param):
# uniform sampling in [-1, 1]
z = torch.rand(size=(param.shape[0], self._noise_dimension),
device=param.device,
dtype=param.dtype,
requires_grad=True)
z = 2. * z - 1.
# conditioning by concatenation of mapped parameters
input_ = torch.cat((z, self.condition(param)), dim=-1)
out = self.model(input_)
return out
# Simple Discriminator Network
class Discriminator(nn.Module):
def __init__(self, input_dimension, parameter_dimension,
hidden_dimension, activation=torch.nn.ReLU):
super().__init__()
self._activation = activation
self.encoding = torch.nn.Sequential(
torch.nn.Linear(input_dimension, input_dimension // 3),
self._activation(),
torch.nn.Linear(input_dimension // 3, input_dimension // 6),
self._activation(),
torch.nn.Linear(input_dimension // 6, hidden_dimension)
)
self.decoding = torch.nn.Sequential(
torch.nn.Linear(2*hidden_dimension, input_dimension // 6),
self._activation(),
torch.nn.Linear(input_dimension // 6, input_dimension // 3),
self._activation(),
torch.nn.Linear(input_dimension // 3, input_dimension),
)
self.condition = torch.nn.Sequential(
torch.nn.Linear(parameter_dimension, hidden_dimension // 2),
self._activation(),
torch.nn.Linear(hidden_dimension // 2, hidden_dimension)
)
def forward(self, data):
x, condition = data
encoding = self.encoding(x)
conditioning = torch.cat((encoding, self.condition(condition)), dim=-1)
decoding = self.decoding(conditioning)
return decoding
problem = ParametricGaussian()
def test_constructor():
GAROM(problem = problem,
generator = Generator(input_dimension=900,
parameters_dimension=2,
noise_dimension=12),
discriminator = Discriminator(input_dimension=900,
parameter_dimension=2,
hidden_dimension=64)
)
def test_train_cpu():
solver = GAROM(problem = problem,
generator = Generator(input_dimension=900,
parameters_dimension=2,
noise_dimension=12),
discriminator = Discriminator(input_dimension=900,
parameter_dimension=2,
hidden_dimension=64)
)
trainer = Trainer(solver=solver, kwargs={'max_epochs' : 4, 'accelerator': 'cpu'})
trainer.train()
def test_sample():
solver = GAROM(problem = problem,
generator = Generator(input_dimension=900,
parameters_dimension=2,
noise_dimension=12),
discriminator = Discriminator(input_dimension=900,
parameter_dimension=2,
hidden_dimension=64)
)
solver.sample(problem.params)
assert solver.sample(problem.params).shape == problem.snapshots.shape
def test_forward():
solver = GAROM(problem = problem,
generator = Generator(input_dimension=900,
parameters_dimension=2,
noise_dimension=12),
discriminator = Discriminator(input_dimension=900,
parameter_dimension=2,
hidden_dimension=64)
)
solver(problem.params, mc_steps=100, variance=True)
assert solver(problem.params).shape == problem.snapshots.shape