add scheduler step for multisolvers (#526)
@@ -151,9 +151,9 @@ class GAROM(MultiSolverInterface):
         :return: The residual loss and the generator loss.
         :rtype: tuple[torch.Tensor, torch.Tensor]
         """
-        optimizer = self.optimizer_generator
-        optimizer.zero_grad()
+        self.optimizer_generator.instance.zero_grad()
 
+        # Generate a batch of images
         generated_snapshots = self.sample(parameters)
 
         # generator loss
@@ -165,7 +165,8 @@ class GAROM(MultiSolverInterface):
 
         # backward step
         g_loss.backward()
-        optimizer.step()
+        self.optimizer_generator.instance.step()
+        self.scheduler_generator.instance.step()
 
         return r_loss, g_loss
 
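The two hunks above drop the local `optimizer` alias in favor of explicit `.instance` access and step the scheduler immediately after the optimizer in GAROM's generator routine. A minimal stand-alone sketch of that pattern with plain torch objects (`generator`, `train_generator_step`, and the loss are illustrative stand-ins, not the solver's actual code):

    import torch

    # Hypothetical stand-ins for the solver's generator and data; only the
    # optimizer/scheduler stepping pattern mirrors the diff above.
    generator = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    def train_generator_step(parameters, snapshots):
        optimizer.zero_grad()
        generated = generator(parameters)   # stands in for self.sample(parameters)
        g_loss = torch.nn.functional.mse_loss(generated, snapshots)
        g_loss.backward()
        optimizer.step()
        scheduler.step()                    # the step this commit adds
        return g_loss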
@@ -196,8 +197,7 @@ class GAROM(MultiSolverInterface):
         :return: The residual loss and the generator loss.
         :rtype: tuple[torch.Tensor, torch.Tensor]
         """
-        optimizer = self.optimizer_discriminator
-        optimizer.zero_grad()
+        self.optimizer_discriminator.instance.zero_grad()
 
         # Generate a batch of images
         generated_snapshots = self.sample(parameters)
@@ -213,7 +213,8 @@ class GAROM(MultiSolverInterface):
 
         # backward step
         d_loss.backward()
-        optimizer.step()
+        self.optimizer_discriminator.instance.step()
+        self.scheduler_discriminator.instance.step()
 
         return d_loss_real, d_loss_fake, d_loss
 
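The discriminator routine mirrors the generator change, so each network ends up owning an independent optimizer/scheduler pair. An illustrative sketch (names and policies invented here, not taken from the solver) of why that matters: the two learning rates can decay under different schedules.

    import torch

    gen = torch.nn.Linear(2, 2)
    disc = torch.nn.Linear(2, 2)
    opt_gen = torch.optim.Adam(gen.parameters(), lr=1e-3)
    opt_disc = torch.optim.Adam(disc.parameters(), lr=1e-4)
    sched_gen = torch.optim.lr_scheduler.ExponentialLR(opt_gen, gamma=0.99)
    sched_disc = torch.optim.lr_scheduler.StepLR(opt_disc, step_size=100, gamma=0.5)
    # each train routine steps only its own pair:
    #   opt_gen.step();  sched_gen.step()    (generator pass)
    #   opt_disc.step(); sched_disc.step()   (discriminator pass)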
@@ -345,7 +346,7 @@ class GAROM(MultiSolverInterface):
         :return: The optimizer for the generator.
         :rtype: Optimizer
         """
-        return self.optimizers[0].instance
+        return self.optimizers[0]
 
     @property
     def optimizer_discriminator(self):
@@ -355,7 +356,7 @@ class GAROM(MultiSolverInterface):
         :return: The optimizer for the discriminator.
         :rtype: Optimizer
         """
-        return self.optimizers[1].instance
+        return self.optimizers[1]
 
     @property
     def scheduler_generator(self):
@@ -365,7 +366,7 @@ class GAROM(MultiSolverInterface):
         :return: The scheduler for the generator.
         :rtype: Scheduler
         """
-        return self.schedulers[0].instance
+        return self.schedulers[0]
 
     @property
     def scheduler_discriminator(self):
@@ -375,4 +376,4 @@ class GAROM(MultiSolverInterface):
         :return: The scheduler for the discriminator.
         :rtype: Scheduler
         """
-        return self.schedulers[1].instance
+        return self.schedulers[1]
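The four property hunks flip the access convention: the properties now return the wrapper object itself, and every call site unwraps it explicitly with `.instance`, which keeps both the optimizer and its scheduler reachable through one object. A hypothetical sketch of that convention (the real wrapper classes live in the library; `TorchWrapper` is invented here):

    import torch

    class TorchWrapper:
        """Holds the raw torch optimizer/scheduler as `.instance`."""
        def __init__(self, instance):
            self.instance = instance

    model = torch.nn.Linear(2, 2)
    optimizers = [TorchWrapper(torch.optim.Adam(model.parameters()))]

    optimizer_generator = optimizers[0]        # property now returns the wrapper
    optimizer_generator.instance.zero_grad()   # call sites unwrap explicitly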
@@ -130,11 +130,15 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
         loss = super().training_step(batch)
         self.manual_backward(loss)
         self.optimizer_model.instance.step()
+        self.scheduler_model.instance.step()
 
         # train discriminator
         self.optimizer_discriminator.instance.zero_grad()
         loss = super().training_step(batch)
         self.manual_backward(-loss)
         self.optimizer_discriminator.instance.step()
+        self.scheduler_discriminator.instance.step()
 
         return loss
 
     def loss_phys(self, samples, equation):
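CompetitivePINN drives both networks by hand inside a single `training_step`, and Lightning does not step schedulers automatically when automatic optimization is off, so each scheduler must be stepped explicitly after its optimizer. A minimal Lightning sketch of that loop, assuming `lightning.pytorch` (the module, layers, and losses are placeholders, not the solver's actual logic):

    import torch
    import lightning.pytorch as pl

    class TwoPlayerSketch(pl.LightningModule):
        """Illustrative two-optimizer module with manual optimization."""
        def __init__(self):
            super().__init__()
            self.automatic_optimization = False   # manual loop, as in the solver
            self.model = torch.nn.Linear(2, 1)
            self.discriminator = torch.nn.Linear(2, 1)

        def training_step(self, batch, batch_idx):
            opt_model, opt_disc = self.optimizers()
            sched_model, sched_disc = self.lr_schedulers()

            # model update: step the scheduler right after the optimizer
            opt_model.zero_grad()
            loss = self.model(batch).pow(2).mean()
            self.manual_backward(loss)
            opt_model.step()
            sched_model.step()

            # discriminator update, ascending via the negated loss
            opt_disc.zero_grad()
            d_loss = -self.discriminator(batch).pow(2).mean()
            self.manual_backward(d_loss)
            opt_disc.step()
            sched_disc.step()
            return loss

        def configure_optimizers(self):
            opt_m = torch.optim.Adam(self.model.parameters(), lr=1e-3)
            opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-3)
            sched_m = torch.optim.lr_scheduler.ConstantLR(opt_m)
            sched_d = torch.optim.lr_scheduler.ConstantLR(opt_d)
            return [opt_m, opt_d], [sched_m, sched_d]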
@@ -188,12 +188,14 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         loss = super().training_step(batch)
         self.manual_backward(-loss)
         self.optimizer_weights.instance.step()
+        self.scheduler_weights.instance.step()
 
         # Model optimization
         self.optimizer_model.instance.zero_grad()
         loss = super().training_step(batch)
         self.manual_backward(loss)
         self.optimizer_model.instance.step()
+        self.scheduler_model.instance.step()
 
         return loss
 
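SelfAdaptivePINN gets the same treatment for its weight and model optimizers. The practical effect of the added calls: under manual optimization, a configured learning-rate schedule never advances unless `scheduler.step()` is invoked. A tiny self-contained demonstration (all names illustrative):

    import torch

    weights = torch.nn.Parameter(torch.ones(10))
    opt = torch.optim.Adam([weights], lr=1e-2)
    sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)

    for step in range(3):
        opt.zero_grad()
        loss = -(weights ** 2).sum()   # descend the negated objective,
        loss.backward()                # mirroring manual_backward(-loss)
        opt.step()
        sched.step()                   # without this, lr stays at 1e-2 forever
        print(step, opt.param_groups[0]["lr"])  # ~9e-3, 8.1e-3, 7.29e-3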