Network handles forward for all solvers

This commit is contained in:
Dario Coscia
2023-11-09 15:16:57 +01:00
committed by Nicola Demo
parent 4844640727
commit c90301c204
5 changed files with 63 additions and 67 deletions

View File

@@ -1,11 +1,12 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..utils import check_consistency from ..utils import check_consistency
from ..label_tensor import LabelTensor
class Network(torch.nn.Module): class Network(torch.nn.Module):
def __init__(self, model, extra_features=None): def __init__(self, model, input_variables, output_variables, extra_features=None):
""" """
Network class with standard forward method Network class with standard forward method
and possibility to pass extra features. This and possibility to pass extra features. This
@@ -14,6 +15,10 @@ class Network(torch.nn.Module):
:param model: The torch model to convert in a PINA model. :param model: The torch model to convert in a PINA model.
:type model: torch.nn.Module :type model: torch.nn.Module
:param list(str) input_variables: The input variables of the :class:`AbstractProblem`, whose type depends on the
type of domain (spatial, temporal, and parameter).
:param list(str) output_variables: The output variables of the :class:`AbstractProblem`, whose type depends on the
problem setting.
:param extra_features: List of torch models to augment the input, defaults to None. :param extra_features: List of torch models to augment the input, defaults to None.
:type extra_features: list(torch.nn.Module) :type extra_features: list(torch.nn.Module)
""" """
@@ -21,7 +26,12 @@ class Network(torch.nn.Module):
# check model consistency # check model consistency
check_consistency(model, nn.Module) check_consistency(model, nn.Module)
check_consistency(input_variables, str)
check_consistency(output_variables, str)
self._model = model self._model = model
self._input_variables = input_variables
self._output_variables = output_variables
# check consistency and assign extra fatures # check consistency and assign extra fatures
if extra_features is None: if extra_features is None:
@@ -46,14 +56,55 @@ class Network(torch.nn.Module):
:param torch.Tensor x: Input of the network. :param torch.Tensor x: Input of the network.
:return torch.Tensor: Output of the network. :return torch.Tensor: Output of the network.
""" """
# extract torch.Tensor from corresponding label
# in case `input_variables = []` all points are used
if self._input_variables:
x = x.extract(self._input_variables)
# extract features and append # extract features and append
for feature in self._extra_features: for feature in self._extra_features:
x = x.append(feature(x)) x = x.append(feature(x))
# perform forward pass
return self._model(x) # convert LabelTensor to torch.Tensor
x = x.as_subclass(torch.Tensor)
# perform forward pass (using torch.Tensor) + converting to LabelTensor
output = self._model(x).as_subclass(LabelTensor)
# set the labels for LabelTensor
output.labels = self._output_variables
return output
def forward_map(self, x):
"""
Forward method for Network class when the input is
a tuple. This class implements the standard forward method,
and it adds the possibility to pass extra features.
All the PINA models ``forward`` s are overriden
by this class, to enable :class:`pina.label_tensor.LabelTensor` labels
extraction.
:param list (torch.Tensor) | tuple(torch.Tensor) x: Input of the network.
:return torch.Tensor: Output of the network.
.. note::
This function does not extract the input variables, all the variables
are used for both tensors. Output variables are correctly applied.
"""
# convert LabelTensor s to torch.Tensor s
x = list(map(lambda x: x.as_subclass(torch.Tensor), x))
# perform forward pass (using torch.Tensor) + converting to LabelTensor
output = self._model(x).as_subclass(LabelTensor)
# set the labels for LabelTensor
output.labels = self._output_variables
return output
@property @property
def model(self): def torchmodel(self):
return self._model return self._model
@property @property

View File

@@ -32,7 +32,6 @@ class GAROM(SolverInterface):
problem, problem,
generator, generator,
discriminator, discriminator,
extra_features=None,
loss=None, loss=None,
optimizer_generator=torch.optim.Adam, optimizer_generator=torch.optim.Adam,
optimizer_generator_kwargs={'lr': 0.001}, optimizer_generator_kwargs={'lr': 0.001},
@@ -58,13 +57,6 @@ class GAROM(SolverInterface):
for the generator. for the generator.
:param torch.nn.Module discriminator: The neural network model to use :param torch.nn.Module discriminator: The neural network model to use
for the discriminator. for the discriminator.
:param torch.nn.Module extra_features: The additional input
features to use as augmented input. It should either be a
list of torch.nn.Module, or a dictionary. If a list it is
passed the extra features are passed to both network. If a
dictionary is passed, the keys must be ``generator`` and
``discriminator`` and the values a list of torch.nn.Module
extra features for each.
:param torch.nn.Module loss: The loss function used as minimizer, :param torch.nn.Module loss: The loss function used as minimizer,
default ``None``. If ``loss`` is ``None`` the defualt default ``None``. If ``loss`` is ``None`` the defualt
``PowerLoss(p=1)`` is used, as in the original paper. ``PowerLoss(p=1)`` is used, as in the original paper.
@@ -97,15 +89,9 @@ class GAROM(SolverInterface):
parameters), and ``output_points``. parameters), and ``output_points``.
""" """
if isinstance(extra_features, dict):
extra_features = [
extra_features['generator'], extra_features['discriminator']
]
super().__init__( super().__init__(
models=[generator, discriminator], models=[generator, discriminator],
problem=problem, problem=problem,
extra_features=extra_features,
optimizers=[optimizer_generator, optimizer_discriminator], optimizers=[optimizer_generator, optimizer_discriminator],
optimizers_kwargs=[ optimizers_kwargs=[
optimizer_generator_kwargs, optimizer_discriminator_kwargs optimizer_generator_kwargs, optimizer_discriminator_kwargs
@@ -200,7 +186,7 @@ class GAROM(SolverInterface):
# generator loss # generator loss
r_loss = self._loss(snapshots, generated_snapshots) r_loss = self._loss(snapshots, generated_snapshots)
d_fake = self.discriminator([generated_snapshots, parameters]) d_fake = self.discriminator.forward_map([generated_snapshots, parameters])
g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
# backward step # backward step
@@ -220,8 +206,8 @@ class GAROM(SolverInterface):
generated_snapshots = self.generator(parameters) generated_snapshots = self.generator(parameters)
# Discriminator pass # Discriminator pass
d_real = self.discriminator([snapshots, parameters]) d_real = self.discriminator.forward_map([snapshots, parameters])
d_fake = self.discriminator([generated_snapshots, parameters]) d_fake = self.discriminator.forward_map([generated_snapshots, parameters])
# evaluate loss # evaluate loss
d_loss_real = self._loss(d_real, snapshots) d_loss_real = self._loss(d_real, snapshots)

View File

@@ -83,13 +83,7 @@ class PINN(SolverInterface):
:return: PINN solution. :return: PINN solution.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
# extract torch.Tensor from corresponding label return self.neural_net(x)
x = x.extract(self.problem.input_variables).as_subclass(torch.Tensor)
# perform forward pass (using torch.Tensor) + converting to LabelTensor
output = self.neural_net(x).as_subclass(LabelTensor)
# set the labels for LabelTensor
output.labels = self.problem.output_variables
return output
def configure_optimizers(self): def configure_optimizers(self):
""" """

View File

@@ -80,7 +80,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
raise ValueError( raise ValueError(
'You passed a list of extrafeatures list with len' 'You passed a list of extrafeatures list with len'
f'different of models len. Expected {len_model} ' f'different of models len. Expected {len_model} '
f'got {len(extra_features)}. If you want to use' f'got {len(extra_features)}. If you want to use '
'the same list of extra features for all models, ' 'the same list of extra features for all models, '
'just pass a list of extrafeatures and not a list ' 'just pass a list of extrafeatures and not a list '
'of list of extra features.') 'of list of extra features.')
@@ -91,6 +91,8 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
for idx in range(len_model): for idx in range(len_model):
model_ = Network(model=models[idx], model_ = Network(model=models[idx],
input_variables=problem.input_variables,
output_variables=problem.output_variables,
extra_features=extra_features[idx]) extra_features=extra_features[idx])
optim_ = optimizers[idx](model_.parameters(), optim_ = optimizers[idx](model_.parameters(),
**optimizers_kwargs[idx]) **optimizers_kwargs[idx])

View File

@@ -72,13 +72,7 @@ class SupervisedSolver(SolverInterface):
:return: Solver solution. :return: Solver solution.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
# extract torch.Tensor from corresponding label return self.neural_net(x)
x = x.extract(self.problem.input_variables).as_subclass(torch.Tensor)
# perform forward pass (using torch.Tensor) + converting to LabelTensor
output = self.neural_net(x).as_subclass(LabelTensor)
# set the labels for LabelTensor
output.labels = self.problem.output_variables
return output
def configure_optimizers(self): def configure_optimizers(self):
"""Optimizer configuration for the solver. """Optimizer configuration for the solver.
@@ -125,37 +119,6 @@ class SupervisedSolver(SolverInterface):
self.log('mean_loss', float(loss), prog_bar=True, logger=True) self.log('mean_loss', float(loss), prog_bar=True, logger=True)
return loss return loss
def training_step_(self, batch, batch_idx):
"""Solver training step.
:param batch: The batch element in the dataloader.
:type batch: tuple
:param batch_idx: The batch index.
:type batch_idx: int
:return: The sum of the loss functions.
:rtype: LabelTensor
"""
for condition_name, samples in batch.items():
if condition_name not in self.problem.conditions:
raise RuntimeError('Something wrong happened.')
condition = self.problem.conditions[condition_name]
# data loss
if hasattr(condition, 'output_points'):
input_pts, output_pts = samples
loss = self.loss(self.forward(input_pts),
output_pts) * condition.data_weight
else:
raise RuntimeError(
'Supervised solver works only in data-driven mode.')
self.log('mean_loss', float(loss), prog_bar=True, logger=True)
return loss
@property @property
def scheduler(self): def scheduler(self):
""" """