add scheduler step for multisolvers (#526)

This commit is contained in:
Giovanni Canali
2025-03-27 15:55:15 +01:00
committed by FilippoOlivo
parent 6d39e2fa98
commit 3958e35fdd
3 changed files with 17 additions and 10 deletions

View File

@@ -130,11 +130,15 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
loss = super().training_step(batch)
self.manual_backward(loss)
self.optimizer_model.instance.step()
self.scheduler_model.instance.step()
# train discriminator
self.optimizer_discriminator.instance.zero_grad()
loss = super().training_step(batch)
self.manual_backward(-loss)
self.optimizer_discriminator.instance.step()
self.scheduler_discriminator.instance.step()
return loss
def loss_phys(self, samples, equation):

View File

@@ -188,12 +188,14 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
loss = super().training_step(batch)
self.manual_backward(-loss)
self.optimizer_weights.instance.step()
self.scheduler_weights.instance.step()
# Model optimization
self.optimizer_model.instance.zero_grad()
loss = super().training_step(batch)
self.manual_backward(loss)
self.optimizer_model.instance.step()
self.scheduler_model.instance.step()
return loss