Solvers for multiple models (#133)
* Solvers for multiple models - Implementing the possibility to add multiple models for solvers (e.g. GAN) - Implementing GAROM solver, see https://arxiv.org/abs/2305.15881 - Implementing tests for GAROM solver (cpu only) - Fixing docs PINNs - Creating a solver directory, for consistency in the package --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-040.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
6c8635c316
commit
701046661f
@@ -10,7 +10,7 @@ __all__ = [
|
||||
|
||||
from .meta import *
|
||||
from .label_tensor import LabelTensor
|
||||
from .pinn import PINN
|
||||
from .solvers.pinn import PINN
|
||||
from .trainer import Trainer
|
||||
from .plotter import Plotter
|
||||
from .condition import Condition
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
""" Solver module. """
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from .model.network import Network
|
||||
import lightning.pytorch as pl
|
||||
from .utils import check_consistency
|
||||
from .problem import AbstractProblem
|
||||
|
||||
class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
|
||||
""" Solver base class. """
|
||||
def __init__(self, model, problem, extra_features=None):
|
||||
"""
|
||||
:param model: A torch neural network model instance.
|
||||
:type model: torch.nn.Module
|
||||
:param problem: A problem definition instance.
|
||||
:type problem: AbstractProblem
|
||||
:param list(torch.nn.Module) extra_features: the additional input
|
||||
features to use as augmented input.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# check inheritance for pina problem
|
||||
check_consistency(problem, AbstractProblem)
|
||||
|
||||
# assigning class variables (check consistency inside Network class)
|
||||
self._pina_model = Network(model=model, extra_features=extra_features)
|
||||
self._problem = problem
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def configure_optimizers(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""
|
||||
The torch model."""
|
||||
return self._pina_model
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
"""
|
||||
The problem formulation."""
|
||||
return self._problem
|
||||
|
||||
# @model.setter
|
||||
# def model(self, new_model):
|
||||
# """
|
||||
# Set the torch."""
|
||||
# check_consistency(new_model, nn.Module, 'torch model')
|
||||
# self._model= new_model
|
||||
|
||||
# @problem.setter
|
||||
# def problem(self, problem):
|
||||
# """
|
||||
# Set the problem formulation."""
|
||||
# check_consistency(problem, AbstractProblem, 'pina problem')
|
||||
# self._problem = problem
|
||||
7
pina/solvers/__init__.py
Normal file
7
pina/solvers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
__all__ = [
|
||||
'PINN',
|
||||
'GAROM',
|
||||
]
|
||||
|
||||
from .garom import GAROM
|
||||
from .pinn import PINN
|
||||
261
pina/solvers/garom.py
Normal file
261
pina/solvers/garom.py
Normal file
@@ -0,0 +1,261 @@
|
||||
""" Module for PINN """
|
||||
import torch
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||
except ImportError:
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # torch < 2.0
|
||||
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
from .solver import SolverInterface
|
||||
from ..utils import check_consistency
|
||||
from ..loss import LossInterface, LpLoss
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
class GAROM(SolverInterface):
|
||||
"""
|
||||
GAROM solver class. This class implements Generative Adversarial
|
||||
Reduced Order Model solver, using user specified ``models`` to solve
|
||||
a specific order reduction``problem``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Coscia, D., Demo, N., & Rozza, G. (2023).
|
||||
Generative Adversarial Reduced Order Modelling.
|
||||
arXiv preprint arXiv:2305.15881.
|
||||
<https://doi.org/10.48550/arXiv.2305.15881>`_.
|
||||
"""
|
||||
def __init__(self,
|
||||
problem,
|
||||
generator,
|
||||
discriminator,
|
||||
extra_features=None,
|
||||
loss = None,
|
||||
optimizer_generator=torch.optim.Adam,
|
||||
optimizer_generator_kwargs={'lr' : 0.001},
|
||||
optimizer_discriminator=torch.optim.Adam,
|
||||
optimizer_discriminator_kwargs={'lr' : 0.001},
|
||||
scheduler_generator=ConstantLR,
|
||||
scheduler_generator_kwargs={"factor": 1, "total_iters": 0},
|
||||
scheduler_discriminator=ConstantLR,
|
||||
scheduler_discriminator_kwargs={"factor": 1, "total_iters": 0},
|
||||
gamma = 0.3,
|
||||
lambda_k = 0.001,
|
||||
regularizer = False,
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
:param torch.nn.Module generator: The neural network model to use
|
||||
for the generator.
|
||||
:param torch.nn.Module discriminator: The neural network model to use
|
||||
for the discriminator.
|
||||
:param torch.nn.Module extra_features: The additional input
|
||||
features to use as augmented input. It should either be a
|
||||
list of torch.nn.Module, or a dictionary. If a list it is
|
||||
passed the extra features are passed to both network. If a
|
||||
dictionary is passed, the keys must be ``generator`` and
|
||||
``discriminator`` and the values a list of torch.nn.Module
|
||||
extra features for each.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default ``None``. If ``loss`` is ``None`` the defualt
|
||||
``LpLoss(p=1)`` is used, as in the original paper.
|
||||
:param torch.optim.Optimizer optimizer_generator: The neural
|
||||
network optimizer to use for the generator network
|
||||
, default is `torch.optim.Adam`.
|
||||
:param dict optimizer_generator_kwargs: Optimizer constructor keyword
|
||||
args. for the generator.
|
||||
:param torch.optim.Optimizer optimizer_discriminator: The neural
|
||||
network optimizer to use for the discriminator network
|
||||
, default is `torch.optim.Adam`.
|
||||
:param dict optimizer_discriminator_kwargs: Optimizer constructor keyword
|
||||
args. for the discriminator.
|
||||
:param torch.optim.LRScheduler scheduler_generator: Learning
|
||||
rate scheduler for the generator.
|
||||
:param dict scheduler_generator_kwargs: LR scheduler constructor keyword args.
|
||||
:param torch.optim.LRScheduler scheduler_discriminator: Learning
|
||||
rate scheduler for the discriminator.
|
||||
:param dict scheduler_discriminator_kwargs: LR scheduler constructor keyword args.
|
||||
:param gamma: Ratio of expected loss for generator and discriminator, defaults to 0.3.
|
||||
:type gamma: float, optional
|
||||
:param lambda_k: Learning rate for control theory optimization, defaults to 0.001.
|
||||
:type lambda_k: float, optional
|
||||
:param regularizer: Regularization term in the GAROM loss, defaults to False.
|
||||
:type regularizer: bool, optional
|
||||
|
||||
.. warning::
|
||||
The algorithm works only for data-driven model. Hence in the ``problem`` definition
|
||||
the codition must only contain ``input_points`` (e.g. coefficient parameters, time
|
||||
parameters), and ``output_points``.
|
||||
"""
|
||||
|
||||
if isinstance(extra_features, dict):
|
||||
extra_features = [extra_features['generator'], extra_features['discriminator']]
|
||||
|
||||
super().__init__(models=[generator, discriminator],
|
||||
problem=problem,
|
||||
extra_features=extra_features,
|
||||
optimizers=[optimizer_generator, optimizer_discriminator],
|
||||
optimizers_kwargs=[optimizer_generator_kwargs, optimizer_discriminator_kwargs])
|
||||
|
||||
# set automatic optimization for GANs
|
||||
self.automatic_optimization = False
|
||||
|
||||
# set loss
|
||||
if loss is None:
|
||||
loss = LpLoss(p=1)
|
||||
|
||||
# check consistency
|
||||
check_consistency(scheduler_generator, LRScheduler, subclass=True)
|
||||
check_consistency(scheduler_generator_kwargs, dict)
|
||||
check_consistency(scheduler_discriminator, LRScheduler, subclass=True)
|
||||
check_consistency(scheduler_discriminator_kwargs, dict)
|
||||
check_consistency(loss, (LossInterface, _Loss))
|
||||
check_consistency(gamma, float)
|
||||
check_consistency(lambda_k, float)
|
||||
check_consistency(regularizer, bool)
|
||||
|
||||
|
||||
# assign schedulers
|
||||
self._schedulers = [scheduler_generator(self.optimizers[0],
|
||||
**scheduler_generator_kwargs),
|
||||
scheduler_discriminator(self.optimizers[1],
|
||||
**scheduler_discriminator_kwargs)]
|
||||
# loss and writer
|
||||
self._loss = loss
|
||||
|
||||
# began hyperparameters
|
||||
self.k = 0
|
||||
self.gamma = gamma
|
||||
self.lambda_k = lambda_k
|
||||
self.regularizer = float(regularizer)
|
||||
|
||||
def forward(self, x, mc_steps=20, variance=False):
|
||||
|
||||
# sampling
|
||||
field_sample = [self.sample(x) for _ in range(mc_steps)]
|
||||
field_sample = torch.stack(field_sample)
|
||||
|
||||
# extract mean
|
||||
mean = field_sample.mean(dim=0)
|
||||
|
||||
if variance:
|
||||
var = field_sample.var(dim=0)
|
||||
return mean, var
|
||||
|
||||
return mean
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Optimizer configuration for the GAROM
|
||||
solver.
|
||||
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
return self.optimizers, self._schedulers
|
||||
|
||||
def sample(self, x):
|
||||
# sampling
|
||||
return self.generator(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""PINN solver training step.
|
||||
|
||||
:param batch: The batch element in the dataloader.
|
||||
:type batch: tuple
|
||||
:param batch_idx: The batch index.
|
||||
:type batch_idx: int
|
||||
:return: The sum of the loss functions.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
|
||||
for condition_name, samples in batch.items():
|
||||
|
||||
if condition_name not in self.problem.conditions:
|
||||
raise RuntimeError('Something wrong happened.')
|
||||
|
||||
condition = self.problem.conditions[condition_name]
|
||||
|
||||
# for data driven mode
|
||||
if hasattr(condition, 'output_points'):
|
||||
|
||||
# get data
|
||||
parameters, input_pts = samples
|
||||
|
||||
# get optimizers
|
||||
opt_gen, opt_disc = self.optimizers
|
||||
|
||||
# ---------------------
|
||||
# Train Discriminator
|
||||
# ---------------------
|
||||
opt_disc.zero_grad()
|
||||
|
||||
# Generate a batch of images
|
||||
gen_imgs = self.generator(parameters)
|
||||
|
||||
# Discriminator pass
|
||||
d_real = self.discriminator([input_pts, parameters])
|
||||
d_fake = self.discriminator([gen_imgs.detach(), parameters])
|
||||
|
||||
# evaluate loss
|
||||
d_loss_real = self._loss(d_real, input_pts)
|
||||
d_loss_fake = self._loss(d_fake, gen_imgs.detach())
|
||||
d_loss = d_loss_real - self.k * d_loss_fake
|
||||
|
||||
# backward step
|
||||
d_loss.backward()
|
||||
opt_disc.step()
|
||||
|
||||
# -----------------
|
||||
# Train Generator
|
||||
# -----------------
|
||||
opt_gen.zero_grad()
|
||||
|
||||
# Generate a batch of images
|
||||
gen_imgs = self.generator(parameters)
|
||||
|
||||
# generator loss
|
||||
r_loss = self._loss(input_pts, gen_imgs)
|
||||
d_fake = self.discriminator([gen_imgs, parameters])
|
||||
g_loss = self._loss(d_fake, gen_imgs) + self.regularizer * r_loss
|
||||
|
||||
# backward step
|
||||
g_loss.backward()
|
||||
opt_gen.step()
|
||||
|
||||
# ----------------
|
||||
# Update weights
|
||||
# ----------------
|
||||
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
|
||||
|
||||
# Update weight term for fake samples
|
||||
self.k += self.lambda_k * diff.item()
|
||||
self.k = min(max(self.k, 0), 1) # Constraint to interval [0, 1]
|
||||
|
||||
else:
|
||||
raise NotImplementedError('GAROM works only in data-driven mode.')
|
||||
|
||||
return
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return self.models[0]
|
||||
|
||||
@property
|
||||
def discriminator(self):
|
||||
return self.models[1]
|
||||
|
||||
@property
|
||||
def optimizer_generator(self):
|
||||
return self.optimizers[0]
|
||||
|
||||
@property
|
||||
def optimizer_discriminator(self):
|
||||
return self.optimizers[1]
|
||||
|
||||
@property
|
||||
def scheduler_generator(self):
|
||||
return self._schedulers[0]
|
||||
|
||||
@property
|
||||
def scheduler_discriminator(self):
|
||||
return self._schedulers[1]
|
||||
@@ -8,10 +8,9 @@ except ImportError:
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
|
||||
from .solver import SolverInterface
|
||||
from .label_tensor import LabelTensor
|
||||
from .utils import check_consistency
|
||||
from .writer import Writer
|
||||
from .loss import LossInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
from ..loss import LossInterface
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
@@ -19,7 +18,18 @@ torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
|
||||
|
||||
class PINN(SolverInterface):
|
||||
"""
|
||||
PINN solver class. This class implements Physics Informed Neural
|
||||
Network solvers, using a user specified ``model`` to solve a specific
|
||||
``problem``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L.,
|
||||
Perdikaris, P., Wang, S., & Yang, L. (2021).
|
||||
Physics-informed machine learning. Nature Reviews Physics, 3(6), 422-440.
|
||||
<https://doi.org/10.1038/s42254-021-00314-5>`_.
|
||||
"""
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
@@ -45,20 +55,21 @@ class PINN(SolverInterface):
|
||||
rate scheduler.
|
||||
:param dict scheduler_kwargs: LR scheduler constructor keyword args.
|
||||
'''
|
||||
super().__init__(model=model, problem=problem, extra_features=extra_features)
|
||||
super().__init__(models=[model],
|
||||
problem=problem,
|
||||
optimizers=[optimizer],
|
||||
optimizers_kwargs=[optimizer_kwargs],
|
||||
extra_features=extra_features)
|
||||
|
||||
# check consistency
|
||||
check_consistency(optimizer, torch.optim.Optimizer, subclass=True)
|
||||
check_consistency(optimizer_kwargs, dict)
|
||||
check_consistency(scheduler, LRScheduler, subclass=True)
|
||||
check_consistency(scheduler_kwargs, dict)
|
||||
check_consistency(loss, (LossInterface, _Loss), subclass=False)
|
||||
|
||||
# assign variables
|
||||
self._optimizer = optimizer(self.model.parameters(), **optimizer_kwargs)
|
||||
self._scheduler = scheduler(self._optimizer, **scheduler_kwargs)
|
||||
self._scheduler = scheduler(self.optimizers[0], **scheduler_kwargs)
|
||||
self._loss = loss
|
||||
self._writer = Writer()
|
||||
self._neural_net = self.models[0]
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
@@ -72,7 +83,7 @@ class PINN(SolverInterface):
|
||||
# extract labels
|
||||
x = x.extract(self.problem.input_variables)
|
||||
# perform forward pass
|
||||
output = self.model(x).as_subclass(LabelTensor)
|
||||
output = self.neural_net(x).as_subclass(LabelTensor)
|
||||
# set the labels
|
||||
output.labels = self.problem.output_variables
|
||||
return output
|
||||
@@ -84,7 +95,7 @@ class PINN(SolverInterface):
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
return [self._optimizer], [self._scheduler]
|
||||
return self.optimizers, [self.scheduler]
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""PINN solver training step.
|
||||
@@ -109,11 +120,11 @@ class PINN(SolverInterface):
|
||||
# PINN loss: equation evaluated on location or input_points
|
||||
if hasattr(condition, 'equation'):
|
||||
target = condition.equation.residual(samples, self.forward(samples))
|
||||
loss = self._loss(torch.zeros_like(target), target)
|
||||
loss = self.loss(torch.zeros_like(target), target)
|
||||
# PINN loss: evaluate model(input_points) vs output_points
|
||||
elif hasattr(condition, 'output_points'):
|
||||
input_pts, output_pts = samples
|
||||
loss = self._loss(self.forward(input_pts), output_pts)
|
||||
loss = self.loss(self.forward(input_pts), output_pts)
|
||||
|
||||
condition_losses.append(loss * condition.data_weight)
|
||||
|
||||
@@ -121,3 +132,24 @@ class PINN(SolverInterface):
|
||||
# we need to pass it as a torch tensor to make everything work
|
||||
total_loss = sum(condition_losses)
|
||||
return total_loss
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""
|
||||
Scheduler for the PINN training.
|
||||
"""
|
||||
return self._scheduler
|
||||
|
||||
@property
|
||||
def neural_net(self):
|
||||
"""
|
||||
Neural network for the PINN training.
|
||||
"""
|
||||
return self._neural_net
|
||||
|
||||
@property
|
||||
def loss(self):
|
||||
"""
|
||||
Loss for the PINN training.
|
||||
"""
|
||||
return self._loss
|
||||
134
pina/solvers/solver.py
Normal file
134
pina/solvers/solver.py
Normal file
@@ -0,0 +1,134 @@
|
||||
""" Solver module. """
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from ..model.network import Network
|
||||
import lightning.pytorch as pl
|
||||
from ..utils import check_consistency
|
||||
from ..problem import AbstractProblem
|
||||
import torch
|
||||
|
||||
|
||||
class SolverInterface(pl.LightningModule, metaclass=ABCMeta):
|
||||
""" Solver base class. """
|
||||
def __init__(self,
|
||||
models,
|
||||
problem,
|
||||
optimizers,
|
||||
optimizers_kwargs,
|
||||
extra_features=None):
|
||||
"""
|
||||
:param models: A torch neural network model instance.
|
||||
:type models: torch.nn.Module
|
||||
:param problem: A problem definition instance.
|
||||
:type problem: AbstractProblem
|
||||
:param list(torch.nn.Module) extra_features: the additional input
|
||||
features to use as augmented input. If ``None`` no extra features
|
||||
are passed. If it is a list of ``torch.nn.Module``, the extra feature
|
||||
list is passed to all models. If it is a list of extra features' lists,
|
||||
each single list of extra feature is passed to a model.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# check consistency of the inputs
|
||||
check_consistency(models, torch.nn.Module)
|
||||
check_consistency(problem, AbstractProblem)
|
||||
check_consistency(optimizers, torch.optim.Optimizer, subclass=True)
|
||||
check_consistency(optimizers_kwargs, dict)
|
||||
|
||||
# put everything in a list if only one input
|
||||
if not isinstance(models, list):
|
||||
models = [models]
|
||||
if not isinstance(optimizers, list):
|
||||
optimizers = [optimizers]
|
||||
optimizers_kwargs = [optimizers_kwargs]
|
||||
|
||||
# number of models and optimizers
|
||||
len_model = len(models)
|
||||
len_optimizer = len(optimizers)
|
||||
len_optimizer_kwargs = len(optimizers_kwargs)
|
||||
|
||||
# check length consistency optimizers
|
||||
if len_model != len_optimizer:
|
||||
raise ValueError('You must define one optimizer for each model.'
|
||||
f'Got {len_model} models, and {len_optimizer}'
|
||||
' optimizers.')
|
||||
|
||||
# check length consistency optimizers kwargs
|
||||
if len_optimizer_kwargs != len_optimizer:
|
||||
raise ValueError('You must define one dictionary of keyword'
|
||||
' arguments for each optimizers.'
|
||||
f'Got {len_optimizer} optimizers, and'
|
||||
f' {len_optimizer_kwargs} dicitionaries')
|
||||
|
||||
# extra features handling
|
||||
if extra_features is None:
|
||||
extra_features = [None] * len_model
|
||||
else:
|
||||
# if we only have a list of extra features
|
||||
if not isinstance(extra_features[0], (tuple, list)):
|
||||
extra_features = [extra_features] * len_model
|
||||
else: # if we have a list of list extra features
|
||||
if len(extra_features) != len_model:
|
||||
raise ValueError('You passed a list of extrafeatures list with len'
|
||||
f'different of models len. Expected {len_model} '
|
||||
f'got {len(extra_features)}. If you want to use'
|
||||
'the same list of extra features for all models, '
|
||||
'just pass a list of extrafeatures and not a list '
|
||||
'of list of extra features.')
|
||||
|
||||
# assigning model and optimizers
|
||||
self._pina_models = []
|
||||
self._pina_optimizers = []
|
||||
|
||||
for idx in range(len_model):
|
||||
model_ = Network(model=models[idx], extra_features=extra_features[idx])
|
||||
optim_ = optimizers[idx](model_.parameters(), **optimizers_kwargs[idx])
|
||||
self._pina_models.append(model_)
|
||||
self._pina_optimizers.append(optim_)
|
||||
|
||||
# assigning problem
|
||||
self._pina_problem = problem
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def configure_optimizers(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
"""
|
||||
The torch model."""
|
||||
return self._pina_models
|
||||
|
||||
@property
|
||||
def optimizers(self):
|
||||
"""
|
||||
The torch model."""
|
||||
return self._pina_optimizers
|
||||
|
||||
@property
|
||||
def problem(self):
|
||||
"""
|
||||
The problem formulation."""
|
||||
return self._pina_problem
|
||||
|
||||
# @model.setter
|
||||
# def model(self, new_model):
|
||||
# """
|
||||
# Set the torch."""
|
||||
# check_consistency(new_model, nn.Module, 'torch model')
|
||||
# self._model= new_model
|
||||
|
||||
# @problem.setter
|
||||
# def problem(self, problem):
|
||||
# """
|
||||
# Set the problem formulation."""
|
||||
# check_consistency(problem, AbstractProblem, 'pina problem')
|
||||
# self._problem = problem
|
||||
@@ -3,7 +3,7 @@
|
||||
import lightning.pytorch as pl
|
||||
from .utils import check_consistency
|
||||
from .dataset import DummyLoader
|
||||
from .solver import SolverInterface
|
||||
from .solvers.solver import SolverInterface
|
||||
|
||||
class Trainer(pl.Trainer):
|
||||
|
||||
|
||||
162
tests/test_solvers/test_garom.py
Normal file
162
tests/test_solvers/test_garom.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import torch
|
||||
|
||||
from pina.problem import AbstractProblem
|
||||
from pina import Condition, LabelTensor
|
||||
from pina.solvers import GAROM
|
||||
from pina.trainer import Trainer
|
||||
import torch.nn as nn
|
||||
import matplotlib.tri as tri
|
||||
|
||||
|
||||
def func(x, mu1, mu2):
|
||||
import torch
|
||||
x_m1 = (x[:, 0] - mu1).pow(2)
|
||||
x_m2 = (x[:, 1] - mu2).pow(2)
|
||||
norm = x[:, 0]**2 + x[:, 1]**2
|
||||
return torch.exp(-(x_m1 + x_m2))
|
||||
|
||||
class ParametricGaussian(AbstractProblem):
|
||||
output_variables = [f'u_{i}' for i in range(900)]
|
||||
|
||||
# params
|
||||
xx = torch.linspace(-1, 1, 20)
|
||||
yy = xx
|
||||
params = LabelTensor(torch.cartesian_prod(xx, yy), labels=['mu1', 'mu2'])
|
||||
|
||||
# define domain
|
||||
x = torch.linspace(-1, 1, 30)
|
||||
domain = torch.cartesian_prod(x, x)
|
||||
triang = tri.Triangulation(domain[:, 0], domain[:, 1])
|
||||
sol = []
|
||||
for p in params:
|
||||
sol.append(func(domain, p[0], p[1]))
|
||||
snapshots = LabelTensor(torch.stack(sol), labels=output_variables)
|
||||
|
||||
# define conditions
|
||||
conditions = {
|
||||
'data': Condition(
|
||||
input_points=params,
|
||||
output_points=snapshots)
|
||||
}
|
||||
|
||||
# simple Generator Network
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, input_dimension, parameters_dimension,
|
||||
noise_dimension, activation=torch.nn.SiLU):
|
||||
super().__init__()
|
||||
|
||||
self._noise_dimension = noise_dimension
|
||||
self._activation = activation
|
||||
|
||||
self.model = torch.nn.Sequential(
|
||||
torch.nn.Linear(6 * self._noise_dimension, input_dimension // 6),
|
||||
self._activation(),
|
||||
torch.nn.Linear(input_dimension // 6, input_dimension // 3),
|
||||
self._activation(),
|
||||
torch.nn.Linear(input_dimension // 3, input_dimension)
|
||||
)
|
||||
self.condition = torch.nn.Sequential(
|
||||
torch.nn.Linear(parameters_dimension, 2 * self._noise_dimension),
|
||||
self._activation(),
|
||||
torch.nn.Linear(2 * self._noise_dimension, 5 * self._noise_dimension)
|
||||
)
|
||||
|
||||
def forward(self, param):
|
||||
# uniform sampling in [-1, 1]
|
||||
z = torch.rand(size=(param.shape[0], self._noise_dimension),
|
||||
device=param.device,
|
||||
dtype=param.dtype,
|
||||
requires_grad=True)
|
||||
z = 2. * z - 1.
|
||||
|
||||
# conditioning by concatenation of mapped parameters
|
||||
input_ = torch.cat((z, self.condition(param)), dim=-1)
|
||||
out = self.model(input_)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# Simple Discriminator Network
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, input_dimension, parameter_dimension,
|
||||
hidden_dimension, activation=torch.nn.ReLU):
|
||||
super().__init__()
|
||||
|
||||
self._activation = activation
|
||||
self.encoding = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_dimension, input_dimension // 3),
|
||||
self._activation(),
|
||||
torch.nn.Linear(input_dimension // 3, input_dimension // 6),
|
||||
self._activation(),
|
||||
torch.nn.Linear(input_dimension // 6, hidden_dimension)
|
||||
)
|
||||
self.decoding = torch.nn.Sequential(
|
||||
torch.nn.Linear(2*hidden_dimension, input_dimension // 6),
|
||||
self._activation(),
|
||||
torch.nn.Linear(input_dimension // 6, input_dimension // 3),
|
||||
self._activation(),
|
||||
torch.nn.Linear(input_dimension // 3, input_dimension),
|
||||
)
|
||||
|
||||
self.condition = torch.nn.Sequential(
|
||||
torch.nn.Linear(parameter_dimension, hidden_dimension // 2),
|
||||
self._activation(),
|
||||
torch.nn.Linear(hidden_dimension // 2, hidden_dimension)
|
||||
)
|
||||
|
||||
def forward(self, data):
|
||||
x, condition = data
|
||||
encoding = self.encoding(x)
|
||||
conditioning = torch.cat((encoding, self.condition(condition)), dim=-1)
|
||||
decoding = self.decoding(conditioning)
|
||||
return decoding
|
||||
|
||||
|
||||
problem = ParametricGaussian()
|
||||
|
||||
def test_constructor():
|
||||
GAROM(problem = problem,
|
||||
generator = Generator(input_dimension=900,
|
||||
parameters_dimension=2,
|
||||
noise_dimension=12),
|
||||
discriminator = Discriminator(input_dimension=900,
|
||||
parameter_dimension=2,
|
||||
hidden_dimension=64)
|
||||
)
|
||||
|
||||
def test_train_cpu():
|
||||
solver = GAROM(problem = problem,
|
||||
generator = Generator(input_dimension=900,
|
||||
parameters_dimension=2,
|
||||
noise_dimension=12),
|
||||
discriminator = Discriminator(input_dimension=900,
|
||||
parameter_dimension=2,
|
||||
hidden_dimension=64)
|
||||
)
|
||||
|
||||
trainer = Trainer(solver=solver, kwargs={'max_epochs' : 4, 'accelerator': 'cpu'})
|
||||
trainer.train()
|
||||
|
||||
def test_sample():
|
||||
solver = GAROM(problem = problem,
|
||||
generator = Generator(input_dimension=900,
|
||||
parameters_dimension=2,
|
||||
noise_dimension=12),
|
||||
discriminator = Discriminator(input_dimension=900,
|
||||
parameter_dimension=2,
|
||||
hidden_dimension=64)
|
||||
)
|
||||
solver.sample(problem.params)
|
||||
assert solver.sample(problem.params).shape == problem.snapshots.shape
|
||||
|
||||
def test_forward():
|
||||
solver = GAROM(problem = problem,
|
||||
generator = Generator(input_dimension=900,
|
||||
parameters_dimension=2,
|
||||
noise_dimension=12),
|
||||
discriminator = Discriminator(input_dimension=900,
|
||||
parameter_dimension=2,
|
||||
hidden_dimension=64)
|
||||
)
|
||||
solver(problem.params, mc_steps=100, variance=True)
|
||||
assert solver(problem.params).shape == problem.snapshots.shape
|
||||
Reference in New Issue
Block a user