Network handles forward for all solvers
This commit is contained in:
committed by
Nicola Demo
parent
4844640727
commit
c90301c204
@@ -32,7 +32,6 @@ class GAROM(SolverInterface):
|
||||
problem,
|
||||
generator,
|
||||
discriminator,
|
||||
extra_features=None,
|
||||
loss=None,
|
||||
optimizer_generator=torch.optim.Adam,
|
||||
optimizer_generator_kwargs={'lr': 0.001},
|
||||
@@ -58,13 +57,6 @@ class GAROM(SolverInterface):
|
||||
for the generator.
|
||||
:param torch.nn.Module discriminator: The neural network model to use
|
||||
for the discriminator.
|
||||
:param torch.nn.Module extra_features: The additional input
|
||||
features to use as augmented input. It should either be a
|
||||
list of torch.nn.Module, or a dictionary. If a list it is
|
||||
passed the extra features are passed to both network. If a
|
||||
dictionary is passed, the keys must be ``generator`` and
|
||||
``discriminator`` and the values a list of torch.nn.Module
|
||||
extra features for each.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default ``None``. If ``loss`` is ``None`` the defualt
|
||||
``PowerLoss(p=1)`` is used, as in the original paper.
|
||||
@@ -97,15 +89,9 @@ class GAROM(SolverInterface):
|
||||
parameters), and ``output_points``.
|
||||
"""
|
||||
|
||||
if isinstance(extra_features, dict):
|
||||
extra_features = [
|
||||
extra_features['generator'], extra_features['discriminator']
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
models=[generator, discriminator],
|
||||
problem=problem,
|
||||
extra_features=extra_features,
|
||||
optimizers=[optimizer_generator, optimizer_discriminator],
|
||||
optimizers_kwargs=[
|
||||
optimizer_generator_kwargs, optimizer_discriminator_kwargs
|
||||
@@ -200,7 +186,7 @@ class GAROM(SolverInterface):
|
||||
|
||||
# generator loss
|
||||
r_loss = self._loss(snapshots, generated_snapshots)
|
||||
d_fake = self.discriminator([generated_snapshots, parameters])
|
||||
d_fake = self.discriminator.forward_map([generated_snapshots, parameters])
|
||||
g_loss = self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
|
||||
|
||||
# backward step
|
||||
@@ -220,8 +206,8 @@ class GAROM(SolverInterface):
|
||||
generated_snapshots = self.generator(parameters)
|
||||
|
||||
# Discriminator pass
|
||||
d_real = self.discriminator([snapshots, parameters])
|
||||
d_fake = self.discriminator([generated_snapshots, parameters])
|
||||
d_real = self.discriminator.forward_map([snapshots, parameters])
|
||||
d_fake = self.discriminator.forward_map([generated_snapshots, parameters])
|
||||
|
||||
# evaluate loss
|
||||
d_loss_real = self._loss(d_real, snapshots)
|
||||
|
||||
Reference in New Issue
Block a user