fixing competitive pinn (#470)

This commit is contained in:
Dario Coscia
2025-03-01 15:03:27 +01:00
committed by Nicola Demo
parent d68b51fce3
commit 4c3e305b09

View File

@@ -3,7 +3,7 @@
import torch
import copy
from pina.problem import InverseProblem
from ...problem import InverseProblem
from .pinn_interface import PINNInterface
from ..solver import MultiSolverInterface
@@ -125,10 +125,15 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
:return: The sum of the loss functions.
:rtype: LabelTensor
"""
# train model
self.optimizer_model.instance.zero_grad()
loss = super().training_step(batch)
self.manual_backward(loss)
self.optimizer_model.instance.step()
# train discriminator
self.optimizer_discriminator.instance.zero_grad()
loss = super().training_step(batch)
self.optimizer_model.instance.step()
self.manual_backward(-loss)
self.optimizer_discriminator.instance.step()
return loss
@@ -144,20 +149,18 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
samples and equation.
:rtype: LabelTensor
"""
# Train the model for one step
with torch.no_grad():
discriminator_bets = self.discriminator(samples)
loss_val = self._train_model(samples, equation, discriminator_bets)
# Detach samples from the existing computational graph and
# create a new one by setting requires_grad to True.
# In alternative set `retain_graph=True`.
samples = samples.detach()
samples.requires_grad_()
# Train the discriminator for one step
# Compute discriminator bets
discriminator_bets = self.discriminator(samples)
self._train_discriminator(samples, equation, discriminator_bets)
# Compute residual and multiply discriminator_bets
residual = self.compute_residual(samples=samples, equation=equation)
residual = residual * discriminator_bets
# Compute competitive residual.
loss_val = self.loss(
torch.zeros_like(residual, requires_grad=True),
residual,
)
return loss_val
def configure_optimizers(self):
@@ -213,58 +216,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
return super().on_train_batch_end(outputs, batch, batch_idx)
def _train_discriminator(self, samples, equation, discriminator_bets):
"""
Trains the discriminator network of the Competitive PINN.
:param LabelTensor samples: Input samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation representing
the physics.
:param Tensor discriminator_bets: Predictions made by the discriminator
network.
"""
# Compute residual. Detach since discriminator weights are fixed
residual = self.compute_residual(
samples=samples, equation=equation
).detach()
# Compute competitive residual, then maximise the loss
competitive_residual = residual * discriminator_bets
loss_val = -self.loss(
torch.zeros_like(competitive_residual, requires_grad=True),
competitive_residual,
)
# prepare for optimizer step called in training step
self.manual_backward(loss_val)
def _train_model(self, samples, equation, discriminator_bets):
"""
Trains the model network of the Competitive PINN.
:param LabelTensor samples: Input samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation representing
the physics.
:param Tensor discriminator_bets: Predictions made by the discriminator.
network.
:return: The computed data loss.
:rtype: torch.Tensor
"""
# Compute residual
residual = self.compute_residual(samples=samples, equation=equation)
with torch.no_grad():
loss_residual = self.loss(torch.zeros_like(residual), residual)
# Compute competitive residual. Detach discriminator_bets
# to optimize only the generator model
competitive_residual = residual * discriminator_bets.detach()
loss_val = self.loss(
torch.zeros_like(competitive_residual, requires_grad=True),
competitive_residual,
)
# prepare for optimizer step called in training step
self.manual_backward(loss_val)
return loss_residual
@property
def neural_net(self):
"""