add scheduler step for multisolvers (#526)
This commit is contained in:
committed by
FilippoOlivo
parent
6d39e2fa98
commit
3958e35fdd
@@ -130,11 +130,15 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
loss = super().training_step(batch)
|
||||
self.manual_backward(loss)
|
||||
self.optimizer_model.instance.step()
|
||||
self.scheduler_model.instance.step()
|
||||
|
||||
# train discriminator
|
||||
self.optimizer_discriminator.instance.zero_grad()
|
||||
loss = super().training_step(batch)
|
||||
self.manual_backward(-loss)
|
||||
self.optimizer_discriminator.instance.step()
|
||||
self.scheduler_discriminator.instance.step()
|
||||
|
||||
return loss
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
|
||||
@@ -188,12 +188,14 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
loss = super().training_step(batch)
|
||||
self.manual_backward(-loss)
|
||||
self.optimizer_weights.instance.step()
|
||||
self.scheduler_weights.instance.step()
|
||||
|
||||
# Model optimization
|
||||
self.optimizer_model.instance.zero_grad()
|
||||
loss = super().training_step(batch)
|
||||
self.manual_backward(loss)
|
||||
self.optimizer_model.instance.step()
|
||||
self.scheduler_model.instance.step()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user