Network handles forward for all solvers
committed by Nicola Demo
parent 4844640727
commit c90301c204

@@ -1,11 +1,12 @@
 import torch
 import torch.nn as nn
 from ..utils import check_consistency
+from ..label_tensor import LabelTensor


 class Network(torch.nn.Module):

-    def __init__(self, model, extra_features=None):
+    def __init__(self, model, input_variables, output_variables, extra_features=None):
         """
         Network class with standard forward method
         and possibility to pass extra features. This
@@ -14,6 +15,10 @@ class Network(torch.nn.Module):

         :param model: The torch model to convert in a PINA model.
         :type model: torch.nn.Module
+        :param list(str) input_variables: The input variables of the :class:`AbstractProblem`, whose type depends on the
+            type of domain (spatial, temporal, and parameter).
+        :param list(str) output_variables: The output variables of the :class:`AbstractProblem`, whose type depends on the
+            problem setting.
         :param extra_features: List of torch models to augment the input, defaults to None.
         :type extra_features: list(torch.nn.Module)
         """
@@ -21,7 +26,12 @@ class Network(torch.nn.Module):

         # check model consistency
         check_consistency(model, nn.Module)
+        check_consistency(input_variables, str)
+        check_consistency(output_variables, str)

         self._model = model
+        self._input_variables = input_variables
+        self._output_variables = output_variables

         # check consistency and assign extra fatures
         if extra_features is None:
@@ -46,14 +56,55 @@ class Network(torch.nn.Module):
         :param torch.Tensor x: Input of the network.
         :return torch.Tensor: Output of the network.
         """
+        # extract torch.Tensor from corresponding label
+        # in case `input_variables = []` all points are used
+        if self._input_variables:
+            x = x.extract(self._input_variables)
+
         # extract features and append
         for feature in self._extra_features:
             x = x.append(feature(x))
-        # perform forward pass
-        return self._model(x)
+
+        # convert LabelTensor to torch.Tensor
+        x = x.as_subclass(torch.Tensor)
+
+        # perform forward pass (using torch.Tensor) + converting to LabelTensor
+        output = self._model(x).as_subclass(LabelTensor)
+
+        # set the labels for LabelTensor
+        output.labels = self._output_variables
+
+        return output
+
+    def forward_map(self, x):
+        """
+        Forward method for Network class when the input is
+        a tuple. This class implements the standard forward method,
+        and it adds the possibility to pass extra features.
+        All the PINA models ``forward`` s are overriden
+        by this class, to enable :class:`pina.label_tensor.LabelTensor` labels
+        extraction.
+
+        :param list(torch.Tensor) | tuple(torch.Tensor) x: Input of the network.
+        :return torch.Tensor: Output of the network.
+
+        .. note::
+            This function does not extract the input variables, all the variables
+            are used for both tensors. Output variables are correctly applied.
+        """
+        # convert LabelTensor s to torch.Tensor s
+        x = list(map(lambda x: x.as_subclass(torch.Tensor), x))
+
+        # perform forward pass (using torch.Tensor) + converting to LabelTensor
+        output = self._model(x).as_subclass(LabelTensor)
+
+        # set the labels for LabelTensor
+        output.labels = self._output_variables
+
+        return output

     @property
-    def model(self):
+    def torchmodel(self):
         return self._model

     @property
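
For orientation, a minimal sketch of how the reworked wrapper behaves; the import path, model, and variable names are illustrative, not part of this commit:

import torch
from pina.label_tensor import LabelTensor
from pina.model.network import Network  # hypothetical import path

# plain torch model mapping two inputs (x, y) to one output (u)
torch_model = torch.nn.Sequential(
    torch.nn.Linear(2, 10), torch.nn.Tanh(), torch.nn.Linear(10, 1))

net = Network(model=torch_model,
              input_variables=['x', 'y'],
              output_variables=['u'])

pts = LabelTensor(torch.rand(8, 2), ['x', 'y'])
out = net(pts)     # extract labels -> torch.Tensor -> model -> LabelTensor
print(out.labels)  # ['u']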

@@ -32,7 +32,6 @@ class GAROM(SolverInterface):
         problem,
         generator,
         discriminator,
-        extra_features=None,
         loss=None,
         optimizer_generator=torch.optim.Adam,
         optimizer_generator_kwargs={'lr': 0.001},
@@ -58,13 +57,6 @@ class GAROM(SolverInterface):
             for the generator.
         :param torch.nn.Module discriminator: The neural network model to use
             for the discriminator.
-        :param torch.nn.Module extra_features: The additional input
-            features to use as augmented input. It should either be a
-            list of torch.nn.Module, or a dictionary. If a list it is
-            passed the extra features are passed to both network. If a
-            dictionary is passed, the keys must be ``generator`` and
-            ``discriminator`` and the values a list of torch.nn.Module
-            extra features for each.
         :param torch.nn.Module loss: The loss function used as minimizer,
             default ``None``. If ``loss`` is ``None`` the defualt
             ``PowerLoss(p=1)`` is used, as in the original paper.
@@ -97,15 +89,9 @@ class GAROM(SolverInterface):
             parameters), and ``output_points``.
         """
-
-        if isinstance(extra_features, dict):
-            extra_features = [
-                extra_features['generator'], extra_features['discriminator']
-            ]

         super().__init__(
             models=[generator, discriminator],
             problem=problem,
-            extra_features=extra_features,
             optimizers=[optimizer_generator, optimizer_discriminator],
             optimizers_kwargs=[
                 optimizer_generator_kwargs, optimizer_discriminator_kwargs
@@ -200,7 +186,7 @@ class GAROM(SolverInterface):

         # generator loss
         r_loss = self._loss(snapshots, generated_snapshots)
-        d_fake = self.discriminator([generated_snapshots, parameters])
+        d_fake = self.discriminator.forward_map([generated_snapshots, parameters])
         g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss

         # backward step
@@ -220,8 +206,8 @@ class GAROM(SolverInterface):
         generated_snapshots = self.generator(parameters)

         # Discriminator pass
-        d_real = self.discriminator([snapshots, parameters])
-        d_fake = self.discriminator([generated_snapshots, parameters])
+        d_real = self.discriminator.forward_map([snapshots, parameters])
+        d_fake = self.discriminator.forward_map([generated_snapshots, parameters])

         # evaluate loss
         d_loss_real = self._loss(d_real, snapshots)
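
GAROM's discriminator consumes two tensors at once, so the solver now calls ``forward_map``, which converts each entry of the list to a plain ``torch.Tensor``, skips per-variable extraction, and labels only the output. A sketch of the calling convention, with a hypothetical toy critic that is not part of this commit:

import torch
from pina.label_tensor import LabelTensor
from pina.model.network import Network  # hypothetical import path

class ToyCritic(torch.nn.Module):
    # hypothetical critic consuming [snapshots, parameters] as a list
    def __init__(self):
        super().__init__()
        self.score = torch.nn.Linear(3, 1)

    def forward(self, inputs):
        snapshots, parameters = inputs  # plain torch.Tensors inside forward_map
        return self.score(torch.cat([snapshots, parameters], dim=-1))

critic = Network(model=ToyCritic(), input_variables=[],
                 output_variables=['score'])

snapshots = LabelTensor(torch.rand(8, 2), ['s1', 's2'])
parameters = LabelTensor(torch.rand(8, 1), ['mu'])
d_out = critic.forward_map([snapshots, parameters])  # output labels: ['score']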

@@ -83,13 +83,7 @@ class PINN(SolverInterface):
         :return: PINN solution.
         :rtype: torch.Tensor
         """
-        # extract torch.Tensor from corresponding label
-        x = x.extract(self.problem.input_variables).as_subclass(torch.Tensor)
-        # perform forward pass (using torch.Tensor) + converting to LabelTensor
-        output = self.neural_net(x).as_subclass(LabelTensor)
-        # set the labels for LabelTensor
-        output.labels = self.problem.output_variables
-        return output
+        return self.neural_net(x)

     def configure_optimizers(self):
         """

@@ -80,7 +80,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
             raise ValueError(
                 'You passed a list of extrafeatures list with len'
                 f'different of models len. Expected {len_model} '
-                f'got {len(extra_features)}. If you want to use'
+                f'got {len(extra_features)}. If you want to use '
                 'the same list of extra features for all models, '
                 'just pass a list of extrafeatures and not a list '
                 'of list of extra features.')
@@ -91,6 +91,8 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):

         for idx in range(len_model):
             model_ = Network(model=models[idx],
+                             input_variables=problem.input_variables,
+                             output_variables=problem.output_variables,
                              extra_features=extra_features[idx])
             optim_ = optimizers[idx](model_.parameters(),
                                      **optimizers_kwargs[idx])
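
With this, ``SolverInterface`` wraps every user model once at construction time, and the solvers themselves never touch labels. A condensed sketch of the wiring above, with the problem and optimizer details stubbed out as illustrative stand-ins:

import torch
from pina.model.network import Network  # hypothetical import path

# stand-ins for problem.input_variables / problem.output_variables
input_variables, output_variables = ['x', 'y'], ['u']
models = [torch.nn.Linear(2, 1)]
extra_features = [None]

wrapped_models, optimizers = [], []
for idx in range(len(models)):
    model_ = Network(model=models[idx],
                     input_variables=input_variables,
                     output_variables=output_variables,
                     extra_features=extra_features[idx])
    optim_ = torch.optim.Adam(model_.parameters(), lr=0.001)
    wrapped_models.append(model_)
    optimizers.append(optim_)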

@@ -72,13 +72,7 @@ class SupervisedSolver(SolverInterface):
         :return: Solver solution.
         :rtype: torch.Tensor
         """
-        # extract torch.Tensor from corresponding label
-        x = x.extract(self.problem.input_variables).as_subclass(torch.Tensor)
-        # perform forward pass (using torch.Tensor) + converting to LabelTensor
-        output = self.neural_net(x).as_subclass(LabelTensor)
-        # set the labels for LabelTensor
-        output.labels = self.problem.output_variables
-        return output
+        return self.neural_net(x)

     def configure_optimizers(self):
         """Optimizer configuration for the solver.
@@ -125,37 +119,6 @@ class SupervisedSolver(SolverInterface):
         self.log('mean_loss', float(loss), prog_bar=True, logger=True)
         return loss

-    def training_step_(self, batch, batch_idx):
-        """Solver training step.
-
-        :param batch: The batch element in the dataloader.
-        :type batch: tuple
-        :param batch_idx: The batch index.
-        :type batch_idx: int
-        :return: The sum of the loss functions.
-        :rtype: LabelTensor
-        """
-
-        for condition_name, samples in batch.items():
-
-            if condition_name not in self.problem.conditions:
-                raise RuntimeError('Something wrong happened.')
-
-            condition = self.problem.conditions[condition_name]
-
-            # data loss
-            if hasattr(condition, 'output_points'):
-                input_pts, output_pts = samples
-                loss = self.loss(self.forward(input_pts),
-                                 output_pts) * condition.data_weight
-            else:
-                raise RuntimeError(
-                    'Supervised solver works only in data-driven mode.')
-
-            self.log('mean_loss', float(loss), prog_bar=True, logger=True)
-            return loss
-
     @property
     def scheduler(self):
         """
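
After this commit both ``PINN.forward`` and ``SupervisedSolver.forward`` are pure delegation: extraction and relabelling happen inside ``Network``, so callers may pass a ``LabelTensor`` carrying extra columns. A sketch, where the wrapper stands in for the ``neural_net`` built by ``SolverInterface``:

import torch
from pina.label_tensor import LabelTensor
from pina.model.network import Network  # hypothetical import path

# stand-in for the solver's neural_net built by SolverInterface
neural_net = Network(model=torch.nn.Linear(2, 1),
                     input_variables=['x', 'y'],
                     output_variables=['u'])

def forward(x):
    # the entire body of PINN.forward / SupervisedSolver.forward now
    return neural_net(x)

pts = LabelTensor(torch.rand(16, 3), ['x', 'y', 'mu'])  # extra 'mu' column
solution = forward(pts)  # Network extracts ['x', 'y']; output labelled ['u']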