fixing competitive pinn (#470)
This commit is contained in:
committed by
Nicola Demo
parent
d68b51fce3
commit
4c3e305b09
@@ -3,7 +3,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
from pina.problem import InverseProblem
|
from ...problem import InverseProblem
|
||||||
from .pinn_interface import PINNInterface
|
from .pinn_interface import PINNInterface
|
||||||
from ..solver import MultiSolverInterface
|
from ..solver import MultiSolverInterface
|
||||||
|
|
||||||
@@ -125,10 +125,15 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
|||||||
:return: The sum of the loss functions.
|
:return: The sum of the loss functions.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
|
# train model
|
||||||
self.optimizer_model.instance.zero_grad()
|
self.optimizer_model.instance.zero_grad()
|
||||||
|
loss = super().training_step(batch)
|
||||||
|
self.manual_backward(loss)
|
||||||
|
self.optimizer_model.instance.step()
|
||||||
|
# train discriminator
|
||||||
self.optimizer_discriminator.instance.zero_grad()
|
self.optimizer_discriminator.instance.zero_grad()
|
||||||
loss = super().training_step(batch)
|
loss = super().training_step(batch)
|
||||||
self.optimizer_model.instance.step()
|
self.manual_backward(-loss)
|
||||||
self.optimizer_discriminator.instance.step()
|
self.optimizer_discriminator.instance.step()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@@ -144,20 +149,18 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
|||||||
samples and equation.
|
samples and equation.
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
# Train the model for one step
|
# Compute discriminator bets
|
||||||
with torch.no_grad():
|
|
||||||
discriminator_bets = self.discriminator(samples)
|
|
||||||
loss_val = self._train_model(samples, equation, discriminator_bets)
|
|
||||||
|
|
||||||
# Detach samples from the existing computational graph and
|
|
||||||
# create a new one by setting requires_grad to True.
|
|
||||||
# In alternative set `retain_graph=True`.
|
|
||||||
samples = samples.detach()
|
|
||||||
samples.requires_grad_()
|
|
||||||
|
|
||||||
# Train the discriminator for one step
|
|
||||||
discriminator_bets = self.discriminator(samples)
|
discriminator_bets = self.discriminator(samples)
|
||||||
self._train_discriminator(samples, equation, discriminator_bets)
|
|
||||||
|
# Compute residual and multiply discriminator_bets
|
||||||
|
residual = self.compute_residual(samples=samples, equation=equation)
|
||||||
|
residual = residual * discriminator_bets
|
||||||
|
|
||||||
|
# Compute competitive residual.
|
||||||
|
loss_val = self.loss(
|
||||||
|
torch.zeros_like(residual, requires_grad=True),
|
||||||
|
residual,
|
||||||
|
)
|
||||||
return loss_val
|
return loss_val
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
@@ -213,58 +216,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
|||||||
|
|
||||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||||
|
|
||||||
def _train_discriminator(self, samples, equation, discriminator_bets):
|
|
||||||
"""
|
|
||||||
Trains the discriminator network of the Competitive PINN.
|
|
||||||
|
|
||||||
:param LabelTensor samples: Input samples to evaluate the physics loss.
|
|
||||||
:param EquationInterface equation: The governing equation representing
|
|
||||||
the physics.
|
|
||||||
:param Tensor discriminator_bets: Predictions made by the discriminator
|
|
||||||
network.
|
|
||||||
"""
|
|
||||||
# Compute residual. Detach since discriminator weights are fixed
|
|
||||||
residual = self.compute_residual(
|
|
||||||
samples=samples, equation=equation
|
|
||||||
).detach()
|
|
||||||
|
|
||||||
# Compute competitive residual, then maximise the loss
|
|
||||||
competitive_residual = residual * discriminator_bets
|
|
||||||
loss_val = -self.loss(
|
|
||||||
torch.zeros_like(competitive_residual, requires_grad=True),
|
|
||||||
competitive_residual,
|
|
||||||
)
|
|
||||||
# prepare for optimizer step called in training step
|
|
||||||
self.manual_backward(loss_val)
|
|
||||||
|
|
||||||
def _train_model(self, samples, equation, discriminator_bets):
|
|
||||||
"""
|
|
||||||
Trains the model network of the Competitive PINN.
|
|
||||||
|
|
||||||
:param LabelTensor samples: Input samples to evaluate the physics loss.
|
|
||||||
:param EquationInterface equation: The governing equation representing
|
|
||||||
the physics.
|
|
||||||
:param Tensor discriminator_bets: Predictions made by the discriminator.
|
|
||||||
network.
|
|
||||||
:return: The computed data loss.
|
|
||||||
:rtype: torch.Tensor
|
|
||||||
"""
|
|
||||||
# Compute residual
|
|
||||||
residual = self.compute_residual(samples=samples, equation=equation)
|
|
||||||
with torch.no_grad():
|
|
||||||
loss_residual = self.loss(torch.zeros_like(residual), residual)
|
|
||||||
|
|
||||||
# Compute competitive residual. Detach discriminator_bets
|
|
||||||
# to optimize only the generator model
|
|
||||||
competitive_residual = residual * discriminator_bets.detach()
|
|
||||||
loss_val = self.loss(
|
|
||||||
torch.zeros_like(competitive_residual, requires_grad=True),
|
|
||||||
competitive_residual,
|
|
||||||
)
|
|
||||||
# prepare for optimizer step called in training step
|
|
||||||
self.manual_backward(loss_val)
|
|
||||||
return loss_residual
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def neural_net(self):
|
def neural_net(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user