From c90301c2048327a5ad9090795dec292f708f343d Mon Sep 17 00:00:00 2001
From: Dario Coscia
Date: Thu, 9 Nov 2023 15:16:57 +0100
Subject: [PATCH] Network handles forward for all solvers

---
 pina/model/network.py      | 59 +++++++++++++++++++++++++++++++++++---
 pina/solvers/garom.py      | 20 ++-----------
 pina/solvers/pinn.py       |  8 +------
 pina/solvers/solver.py     |  4 ++-
 pina/solvers/supervised.py | 39 +-------------------------
 5 files changed, 63 insertions(+), 67 deletions(-)

diff --git a/pina/model/network.py b/pina/model/network.py
index f27d5b2..0f84897 100644
--- a/pina/model/network.py
+++ b/pina/model/network.py
@@ -1,11 +1,12 @@
 import torch
 import torch.nn as nn
 from ..utils import check_consistency
+from ..label_tensor import LabelTensor
 
 
 class Network(torch.nn.Module):
 
-    def __init__(self, model, extra_features=None):
+    def __init__(self, model, input_variables, output_variables, extra_features=None):
         """
         Network class with standard forward method
         and possibility to pass extra features. This
@@ -14,6 +15,10 @@ class Network(torch.nn.Module):
 
         :param model: The torch model to convert in a PINA model.
         :type model: torch.nn.Module
+        :param list(str) input_variables: The input variables of the :class:`AbstractProblem`, whose type depends on the
+            type of domain (spatial, temporal, and parameter).
+        :param list(str) output_variables: The output variables of the :class:`AbstractProblem`, whose type depends on the
+            problem setting.
         :param extra_features: List of torch models to augment the input, defaults to None.
         :type extra_features: list(torch.nn.Module)
         """
@@ -21,7 +26,12 @@
 
         # check model consistency
         check_consistency(model, nn.Module)
+        check_consistency(input_variables, str)
+        check_consistency(output_variables, str)
+
         self._model = model
+        self._input_variables = input_variables
+        self._output_variables = output_variables
 
         # check consistency and assign extra features
         if extra_features is None:
@@ -46,14 +56,55 @@
         :param torch.Tensor x: Input of the network.
         :return torch.Tensor: Output of the network.
         """
+        # extract torch.Tensor from corresponding label
+        # in case `input_variables = []` all points are used
+        if self._input_variables:
+            x = x.extract(self._input_variables)
+
         # extract features and append
         for feature in self._extra_features:
             x = x.append(feature(x))
-        # perform forward pass
-        return self._model(x)
+
+        # convert LabelTensor to torch.Tensor
+        x = x.as_subclass(torch.Tensor)
+
+        # perform forward pass (using torch.Tensor) + converting to LabelTensor
+        output = self._model(x).as_subclass(LabelTensor)
+
+        # set the labels for LabelTensor
+        output.labels = self._output_variables
+
+        return output
+
+    def forward_map(self, x):
+        """
+        Forward method for the Network class when the input is
+        a list or tuple of tensors. This method implements the
+        standard forward pass over all the inputs at once.
+        All the PINA models' ``forward`` methods are overridden
+        by this class, to enable :class:`pina.label_tensor.LabelTensor` labels
+        extraction.
+
+        :param list(torch.Tensor) | tuple(torch.Tensor) x: Input of the network.
+        :return torch.Tensor: Output of the network.
+
+        .. note::
+            This function does not extract the input variables; all the
+            variables are used for both tensors. Output variables are
+            correctly applied.
+        """
+        # convert LabelTensors to torch.Tensors
+        x = list(map(lambda x: x.as_subclass(torch.Tensor), x))
+
+        # perform forward pass (using torch.Tensor) + converting to LabelTensor
+        output = self._model(x).as_subclass(LabelTensor)
+
+        # set the labels for LabelTensor
+        output.labels = self._output_variables
+
+        return output
 
     @property
-    def model(self):
+    def torchmodel(self):
         return self._model
 
     @property
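With this change the wrapper owns the whole label round-trip: extract the problem's input variables, append any extra features, run the bare torch model on plain tensors, and re-label the result. A minimal usage sketch, assuming only the Network and LabelTensor interfaces visible in the hunks above (the Linear model and the variable names are illustrative, not part of the patch):

    import torch
    from pina import LabelTensor
    from pina.model.network import Network

    # a plain torch model mapping the 2 labelled inputs to 1 labelled output
    net = Network(model=torch.nn.Linear(2, 1),
                  input_variables=['x', 'y'],
                  output_variables=['u'])

    pts = LabelTensor(torch.rand(10, 2), ['x', 'y'])
    out = net(pts)
    # forward() extracted ['x', 'y'], ran the model on a torch.Tensor,
    # and returned a LabelTensor with out.labels == ['u']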
+ """ + # convert LabelTensor s to torch.Tensor s + x = list(map(lambda x: x.as_subclass(torch.Tensor), x)) + + # perform forward pass (using torch.Tensor) + converting to LabelTensor + output = self._model(x).as_subclass(LabelTensor) + + # set the labels for LabelTensor + output.labels = self._output_variables + + return output @property - def model(self): + def torchmodel(self): return self._model @property diff --git a/pina/solvers/garom.py b/pina/solvers/garom.py index 972f156..4e41af8 100644 --- a/pina/solvers/garom.py +++ b/pina/solvers/garom.py @@ -32,7 +32,6 @@ class GAROM(SolverInterface): problem, generator, discriminator, - extra_features=None, loss=None, optimizer_generator=torch.optim.Adam, optimizer_generator_kwargs={'lr': 0.001}, @@ -58,13 +57,6 @@ class GAROM(SolverInterface): for the generator. :param torch.nn.Module discriminator: The neural network model to use for the discriminator. - :param torch.nn.Module extra_features: The additional input - features to use as augmented input. It should either be a - list of torch.nn.Module, or a dictionary. If a list it is - passed the extra features are passed to both network. If a - dictionary is passed, the keys must be ``generator`` and - ``discriminator`` and the values a list of torch.nn.Module - extra features for each. :param torch.nn.Module loss: The loss function used as minimizer, default ``None``. If ``loss`` is ``None`` the defualt ``PowerLoss(p=1)`` is used, as in the original paper. @@ -97,15 +89,9 @@ class GAROM(SolverInterface): parameters), and ``output_points``. """ - if isinstance(extra_features, dict): - extra_features = [ - extra_features['generator'], extra_features['discriminator'] - ] - super().__init__( models=[generator, discriminator], problem=problem, - extra_features=extra_features, optimizers=[optimizer_generator, optimizer_discriminator], optimizers_kwargs=[ optimizer_generator_kwargs, optimizer_discriminator_kwargs @@ -200,7 +186,7 @@ class GAROM(SolverInterface): # generator loss r_loss = self._loss(snapshots, generated_snapshots) - d_fake = self.discriminator([generated_snapshots, parameters]) + d_fake = self.discriminator.forward_map([generated_snapshots, parameters]) g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss # backward step @@ -220,8 +206,8 @@ class GAROM(SolverInterface): generated_snapshots = self.generator(parameters) # Discriminator pass - d_real = self.discriminator([snapshots, parameters]) - d_fake = self.discriminator([generated_snapshots, parameters]) + d_real = self.discriminator.forward_map([snapshots, parameters]) + d_fake = self.discriminator.forward_map([generated_snapshots, parameters]) # evaluate loss d_loss_real = self._loss(d_real, snapshots) diff --git a/pina/solvers/pinn.py b/pina/solvers/pinn.py index 746cae7..a1581a4 100644 --- a/pina/solvers/pinn.py +++ b/pina/solvers/pinn.py @@ -83,13 +83,7 @@ class PINN(SolverInterface): :return: PINN solution. 
diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py
index 698e088..e062eda 100644
--- a/pina/solvers/solver.py
+++ b/pina/solvers/solver.py
@@ -80,7 +80,7 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta):
             raise ValueError(
                 'You passed a list of extrafeatures list with len'
                 f'different of models len. Expected {len_model} '
-                f'got {len(extra_features)}. If you want to use'
+                f'got {len(extra_features)}. If you want to use '
                 'the same list of extra features for all models, '
                 'just pass a list of extrafeatures and not a list '
                 'of list of extra features.')
@@ -91,6 +91,8 @@
 
         for idx in range(len_model):
             model_ = Network(model=models[idx],
+                             input_variables=problem.input_variables,
+                             output_variables=problem.output_variables,
                              extra_features=extra_features[idx])
             optim_ = optimizers[idx](model_.parameters(),
                                      **optimizers_kwargs[idx])
diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py
index fb3df56..c98146b 100644
--- a/pina/solvers/supervised.py
+++ b/pina/solvers/supervised.py
@@ -72,13 +72,7 @@ class SupervisedSolver(SolverInterface):
         :return: Solver solution.
         :rtype: torch.Tensor
         """
-        # extract torch.Tensor from corresponding label
-        x = x.extract(self.problem.input_variables).as_subclass(torch.Tensor)
-        # perform forward pass (using torch.Tensor) + converting to LabelTensor
-        output = self.neural_net(x).as_subclass(LabelTensor)
-        # set the labels for LabelTensor
-        output.labels = self.problem.output_variables
-        return output
+        return self.neural_net(x)
 
     def configure_optimizers(self):
         """Optimizer configuration for the solver.
@@ -125,37 +119,6 @@
         self.log('mean_loss', float(loss), prog_bar=True, logger=True)
         return loss
 
-    def training_step_(self, batch, batch_idx):
-        """Solver training step.
-
-        :param batch: The batch element in the dataloader.
-        :type batch: tuple
-        :param batch_idx: The batch index.
-        :type batch_idx: int
-        :return: The sum of the loss functions.
-        :rtype: LabelTensor
-        """
-
-        for condition_name, samples in batch.items():
-
-            if condition_name not in self.problem.conditions:
-                raise RuntimeError('Something wrong happened.')
-
-            condition = self.problem.conditions[condition_name]
-
-            # data loss
-            if hasattr(condition, 'output_points'):
-                input_pts, output_pts = samples
-                loss = self.loss(self.forward(input_pts),
-                                 output_pts) * condition.data_weight
-            else:
-                raise RuntimeError(
-                    'Supervised solver works only in data-driven mode.')
-
-        self.log('mean_loss', float(loss), prog_bar=True, logger=True)
-        return loss
-
     @property
     def scheduler(self):
         """
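Because SolverInterface now builds the Network wrappers with the problem's variables, extra features keep working through the new path: they are appended after the input variables are extracted. A sketch in the style of PINA's extra-feature examples (the sin(xy) feature and the dimensions are illustrative only):

    import torch
    from pina import LabelTensor
    from pina.model.network import Network

    class SinXY(torch.nn.Module):
        # extra feature appended to the input: sin(x * y)
        def forward(self, x):
            t = x.extract(['x']) * x.extract(['y'])
            return LabelTensor(torch.sin(t), ['sin(xy)'])

    # the torch model sees 3 columns: x, y, and the appended feature
    net = Network(model=torch.nn.Linear(3, 1),
                  input_variables=['x', 'y'],
                  output_variables=['u'],
                  extra_features=[SinXY()])

    pts = LabelTensor(torch.rand(10, 2), ['x', 'y'])
    out = net(pts)  # LabelTensor labelled ['u']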