add scheduler step for multisolvers (#526)

This commit is contained in:
Giovanni Canali
2025-03-27 15:55:15 +01:00
committed by GitHub
parent b958c0f5db
commit f48da47ed4
3 changed files with 17 additions and 10 deletions

View File

@@ -151,9 +151,9 @@ class GAROM(MultiSolverInterface):
:return: The residual loss and the generator loss. :return: The residual loss and the generator loss.
:rtype: tuple[torch.Tensor, torch.Tensor] :rtype: tuple[torch.Tensor, torch.Tensor]
""" """
optimizer = self.optimizer_generator self.optimizer_generator.instance.zero_grad()
optimizer.zero_grad()
# Generate a batch of images
generated_snapshots = self.sample(parameters) generated_snapshots = self.sample(parameters)
# generator loss # generator loss
@@ -165,7 +165,8 @@ class GAROM(MultiSolverInterface):
# backward step # backward step
g_loss.backward() g_loss.backward()
optimizer.step() self.optimizer_generator.instance.step()
self.scheduler_generator.instance.step()
return r_loss, g_loss return r_loss, g_loss
@@ -196,8 +197,7 @@ class GAROM(MultiSolverInterface):
:return: The residual loss and the generator loss. :return: The residual loss and the generator loss.
:rtype: tuple[torch.Tensor, torch.Tensor] :rtype: tuple[torch.Tensor, torch.Tensor]
""" """
optimizer = self.optimizer_discriminator self.optimizer_discriminator.instance.zero_grad()
optimizer.zero_grad()
# Generate a batch of images # Generate a batch of images
generated_snapshots = self.sample(parameters) generated_snapshots = self.sample(parameters)
@@ -213,7 +213,8 @@ class GAROM(MultiSolverInterface):
# backward step # backward step
d_loss.backward() d_loss.backward()
optimizer.step() self.optimizer_discriminator.instance.step()
self.scheduler_discriminator.instance.step()
return d_loss_real, d_loss_fake, d_loss return d_loss_real, d_loss_fake, d_loss
@@ -345,7 +346,7 @@ class GAROM(MultiSolverInterface):
:return: The optimizer for the generator. :return: The optimizer for the generator.
:rtype: Optimizer :rtype: Optimizer
""" """
return self.optimizers[0].instance return self.optimizers[0]
@property @property
def optimizer_discriminator(self): def optimizer_discriminator(self):
@@ -355,7 +356,7 @@ class GAROM(MultiSolverInterface):
:return: The optimizer for the discriminator. :return: The optimizer for the discriminator.
:rtype: Optimizer :rtype: Optimizer
""" """
return self.optimizers[1].instance return self.optimizers[1]
@property @property
def scheduler_generator(self): def scheduler_generator(self):
@@ -365,7 +366,7 @@ class GAROM(MultiSolverInterface):
:return: The scheduler for the generator. :return: The scheduler for the generator.
:rtype: Scheduler :rtype: Scheduler
""" """
return self.schedulers[0].instance return self.schedulers[0]
@property @property
def scheduler_discriminator(self): def scheduler_discriminator(self):
@@ -375,4 +376,4 @@ class GAROM(MultiSolverInterface):
:return: The scheduler for the discriminator. :return: The scheduler for the discriminator.
:rtype: Scheduler :rtype: Scheduler
""" """
return self.schedulers[1].instance return self.schedulers[1]

View File

@@ -130,11 +130,15 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
loss = super().training_step(batch) loss = super().training_step(batch)
self.manual_backward(loss) self.manual_backward(loss)
self.optimizer_model.instance.step() self.optimizer_model.instance.step()
self.scheduler_model.instance.step()
# train discriminator # train discriminator
self.optimizer_discriminator.instance.zero_grad() self.optimizer_discriminator.instance.zero_grad()
loss = super().training_step(batch) loss = super().training_step(batch)
self.manual_backward(-loss) self.manual_backward(-loss)
self.optimizer_discriminator.instance.step() self.optimizer_discriminator.instance.step()
self.scheduler_discriminator.instance.step()
return loss return loss
def loss_phys(self, samples, equation): def loss_phys(self, samples, equation):

View File

@@ -188,12 +188,14 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
loss = super().training_step(batch) loss = super().training_step(batch)
self.manual_backward(-loss) self.manual_backward(-loss)
self.optimizer_weights.instance.step() self.optimizer_weights.instance.step()
self.scheduler_weights.instance.step()
# Model optimization # Model optimization
self.optimizer_model.instance.zero_grad() self.optimizer_model.instance.zero_grad()
loss = super().training_step(batch) loss = super().training_step(batch)
self.manual_backward(loss) self.manual_backward(loss)
self.optimizer_model.instance.step() self.optimizer_model.instance.step()
self.scheduler_model.instance.step()
return loss return loss