add scheduler step for multisolvers (#526)
This commit is contained in:
@@ -151,9 +151,9 @@ class GAROM(MultiSolverInterface):
|
||||
:return: The residual loss and the generator loss.
|
||||
:rtype: tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
optimizer = self.optimizer_generator
|
||||
optimizer.zero_grad()
|
||||
self.optimizer_generator.instance.zero_grad()
|
||||
|
||||
# Generate a batch of images
|
||||
generated_snapshots = self.sample(parameters)
|
||||
|
||||
# generator loss
|
||||
@@ -165,7 +165,8 @@ class GAROM(MultiSolverInterface):
|
||||
|
||||
# backward step
|
||||
g_loss.backward()
|
||||
optimizer.step()
|
||||
self.optimizer_generator.instance.step()
|
||||
self.scheduler_generator.instance.step()
|
||||
|
||||
return r_loss, g_loss
|
||||
|
||||
@@ -196,8 +197,7 @@ class GAROM(MultiSolverInterface):
|
||||
:return: The residual loss and the generator loss.
|
||||
:rtype: tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
optimizer = self.optimizer_discriminator
|
||||
optimizer.zero_grad()
|
||||
self.optimizer_discriminator.instance.zero_grad()
|
||||
|
||||
# Generate a batch of images
|
||||
generated_snapshots = self.sample(parameters)
|
||||
@@ -213,7 +213,8 @@ class GAROM(MultiSolverInterface):
|
||||
|
||||
# backward step
|
||||
d_loss.backward()
|
||||
optimizer.step()
|
||||
self.optimizer_discriminator.instance.step()
|
||||
self.scheduler_discriminator.instance.step()
|
||||
|
||||
return d_loss_real, d_loss_fake, d_loss
|
||||
|
||||
@@ -345,7 +346,7 @@ class GAROM(MultiSolverInterface):
|
||||
:return: The optimizer for the generator.
|
||||
:rtype: Optimizer
|
||||
"""
|
||||
return self.optimizers[0].instance
|
||||
return self.optimizers[0]
|
||||
|
||||
@property
|
||||
def optimizer_discriminator(self):
|
||||
@@ -355,7 +356,7 @@ class GAROM(MultiSolverInterface):
|
||||
:return: The optimizer for the discriminator.
|
||||
:rtype: Optimizer
|
||||
"""
|
||||
return self.optimizers[1].instance
|
||||
return self.optimizers[1]
|
||||
|
||||
@property
|
||||
def scheduler_generator(self):
|
||||
@@ -365,7 +366,7 @@ class GAROM(MultiSolverInterface):
|
||||
:return: The scheduler for the generator.
|
||||
:rtype: Scheduler
|
||||
"""
|
||||
return self.schedulers[0].instance
|
||||
return self.schedulers[0]
|
||||
|
||||
@property
|
||||
def scheduler_discriminator(self):
|
||||
@@ -375,4 +376,4 @@ class GAROM(MultiSolverInterface):
|
||||
:return: The scheduler for the discriminator.
|
||||
:rtype: Scheduler
|
||||
"""
|
||||
return self.schedulers[1].instance
|
||||
return self.schedulers[1]
|
||||
|
||||
@@ -130,11 +130,15 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
loss = super().training_step(batch)
|
||||
self.manual_backward(loss)
|
||||
self.optimizer_model.instance.step()
|
||||
self.scheduler_model.instance.step()
|
||||
|
||||
# train discriminator
|
||||
self.optimizer_discriminator.instance.zero_grad()
|
||||
loss = super().training_step(batch)
|
||||
self.manual_backward(-loss)
|
||||
self.optimizer_discriminator.instance.step()
|
||||
self.scheduler_discriminator.instance.step()
|
||||
|
||||
return loss
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
|
||||
@@ -188,12 +188,14 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
|
||||
loss = super().training_step(batch)
|
||||
self.manual_backward(-loss)
|
||||
self.optimizer_weights.instance.step()
|
||||
self.scheduler_weights.instance.step()
|
||||
|
||||
# Model optimization
|
||||
self.optimizer_model.instance.zero_grad()
|
||||
loss = super().training_step(batch)
|
||||
self.manual_backward(loss)
|
||||
self.optimizer_model.instance.step()
|
||||
self.scheduler_model.instance.step()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user