Fixing competitive PINN (#470)
Committed by Nicola Demo
parent d68b51fce3 · commit 4c3e305b09
@@ -3,7 +3,7 @@
 import torch
 import copy
 
-from pina.problem import InverseProblem
+from ...problem import InverseProblem
 from .pinn_interface import PINNInterface
 from ..solver import MultiSolverInterface
 
@@ -125,10 +125,15 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
         :return: The sum of the loss functions.
         :rtype: LabelTensor
         """
+        # train model
         self.optimizer_model.instance.zero_grad()
         loss = super().training_step(batch)
+        self.manual_backward(loss)
+        self.optimizer_model.instance.step()
+        # train discriminator
         self.optimizer_discriminator.instance.zero_grad()
-        self.optimizer_model.instance.step()
+        loss = super().training_step(batch)
+        self.manual_backward(-loss)
         self.optimizer_discriminator.instance.step()
         return loss
 
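For context, the reworked training_step is the standard two-optimizer manual-optimization pattern: the model minimizes the physics loss, while the discriminator maximizes the same quantity by backpropagating the negated loss. The sketch below reproduces that update cycle in plain PyTorch; the networks, data, and compute_loss stand-in are hypothetical and only illustrate the alternating min/max updates, not PINA's actual API.

    import torch

    # Hypothetical stand-ins for the PINA model, discriminator, and loss.
    model = torch.nn.Linear(2, 1)
    discriminator = torch.nn.Linear(2, 1)
    opt_model = torch.optim.Adam(model.parameters(), lr=1e-3)
    opt_disc = torch.optim.Adam(discriminator.parameters(), lr=1e-3)

    def compute_loss(x):
        # Placeholder for the competitive loss: a residual weighted by the
        # discriminator output ("bets"), measured against zero.
        residual = model(x)          # stands in for the PDE residual
        bets = discriminator(x)
        return torch.mean((bets * residual) ** 2)

    x = torch.rand(16, 2)

    # Train the model: minimize the loss.
    opt_model.zero_grad()
    loss = compute_loss(x)
    loss.backward()
    opt_model.step()

    # Train the discriminator: maximize the same loss by backpropagating -loss.
    opt_disc.zero_grad()
    loss = compute_loss(x)
    (-loss).backward()
    opt_disc.step()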
@@ -144,20 +149,18 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
             samples and equation.
         :rtype: LabelTensor
         """
-        # Train the model for one step
-        with torch.no_grad():
-            discriminator_bets = self.discriminator(samples)
-        loss_val = self._train_model(samples, equation, discriminator_bets)
+        # Compute discriminator bets
+        discriminator_bets = self.discriminator(samples)
 
-        # Detach samples from the existing computational graph and
-        # create a new one by setting requires_grad to True.
-        # In alternative set `retain_graph=True`.
-        samples = samples.detach()
-        samples.requires_grad_()
+        # Compute residual and multiply discriminator_bets
+        residual = self.compute_residual(samples=samples, equation=equation)
+        residual = residual * discriminator_bets
 
-        # Train the discriminator for one step
-        discriminator_bets = self.discriminator(samples)
-        self._train_discriminator(samples, equation, discriminator_bets)
+        # Compute competitive residual.
+        loss_val = self.loss(
+            torch.zeros_like(residual, requires_grad=True),
+            residual,
+        )
         return loss_val
 
     def configure_optimizers(self):
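The simplified physics loss above boils down to weighting the PDE residual pointwise by the discriminator output ("bets") and measuring the weighted residual against zero. A rough stand-alone sketch of that computation follows; the tensors and the choice of MSE are illustrative assumptions, since in PINA the residual comes from compute_residual and the metric from the solver's loss attribute.

    import torch

    def competitive_loss(residual, discriminator_bets):
        # Pointwise product: the discriminator re-weights each collocation point.
        competitive_residual = residual * discriminator_bets
        # Measure the weighted residual against zero (MSE used here as an example).
        return torch.nn.functional.mse_loss(
            torch.zeros_like(competitive_residual), competitive_residual
        )

    # Hypothetical residual and bets at 8 collocation points.
    residual = torch.randn(8, 1)
    discriminator_bets = torch.rand(8, 1)
    print(competitive_loss(residual, discriminator_bets))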
@@ -213,58 +216,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
 
         return super().on_train_batch_end(outputs, batch, batch_idx)
 
-    def _train_discriminator(self, samples, equation, discriminator_bets):
-        """
-        Trains the discriminator network of the Competitive PINN.
-
-        :param LabelTensor samples: Input samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation representing
-            the physics.
-        :param Tensor discriminator_bets: Predictions made by the discriminator
-            network.
-        """
-        # Compute residual. Detach since discriminator weights are fixed
-        residual = self.compute_residual(
-            samples=samples, equation=equation
-        ).detach()
-
-        # Compute competitive residual, then maximise the loss
-        competitive_residual = residual * discriminator_bets
-        loss_val = -self.loss(
-            torch.zeros_like(competitive_residual, requires_grad=True),
-            competitive_residual,
-        )
-        # prepare for optimizer step called in training step
-        self.manual_backward(loss_val)
-
-    def _train_model(self, samples, equation, discriminator_bets):
-        """
-        Trains the model network of the Competitive PINN.
-
-        :param LabelTensor samples: Input samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation representing
-            the physics.
-        :param Tensor discriminator_bets: Predictions made by the discriminator.
-            network.
-        :return: The computed data loss.
-        :rtype: torch.Tensor
-        """
-        # Compute residual
-        residual = self.compute_residual(samples=samples, equation=equation)
-        with torch.no_grad():
-            loss_residual = self.loss(torch.zeros_like(residual), residual)
-
-        # Compute competitive residual. Detach discriminator_bets
-        # to optimize only the generator model
-        competitive_residual = residual * discriminator_bets.detach()
-        loss_val = self.loss(
-            torch.zeros_like(competitive_residual, requires_grad=True),
-            competitive_residual,
-        )
-        # prepare for optimizer step called in training step
-        self.manual_backward(loss_val)
-        return loss_residual
-
     @property
     def neural_net(self):
         """