🎨 Format Python code with psf/black (#297)
Co-authored-by: dario-coscia <dario-coscia@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e0429bb445
commit
9463ae4b15
@@ -117,7 +117,7 @@ class CompetitivePINN(PINNInterface):
|
||||
optimizer_discriminator_kwargs,
|
||||
],
|
||||
extra_features=None, # CompetitivePINN doesn't take extra features
|
||||
loss=loss
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
# set automatic optimization for GANs
|
||||
@@ -131,9 +131,7 @@ class CompetitivePINN(PINNInterface):
|
||||
|
||||
# assign schedulers
|
||||
self._schedulers = [
|
||||
scheduler_model(
|
||||
self.optimizers[0], **scheduler_model_kwargs
|
||||
),
|
||||
scheduler_model(self.optimizers[0], **scheduler_model_kwargs),
|
||||
scheduler_discriminator(
|
||||
self.optimizers[1], **scheduler_discriminator_kwargs
|
||||
),
|
||||
@@ -141,7 +139,7 @@ class CompetitivePINN(PINNInterface):
|
||||
|
||||
self._model = self.models[0]
|
||||
self._discriminator = self.models[1]
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Forward pass implementation for the PINN solver. It returns the function
|
||||
@@ -195,8 +193,11 @@ class CompetitivePINN(PINNInterface):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
self.optimizer_model.zero_grad()
|
||||
loss_val = super().loss_data(
|
||||
input_tensor, output_tensor).as_subclass(torch.Tensor)
|
||||
loss_val = (
|
||||
super()
|
||||
.loss_data(input_tensor, output_tensor)
|
||||
.as_subclass(torch.Tensor)
|
||||
)
|
||||
loss_val.backward()
|
||||
self.optimizer_model.step()
|
||||
return loss_val
|
||||
@@ -221,7 +222,7 @@ class CompetitivePINN(PINNInterface):
|
||||
)
|
||||
return self.optimizers, self._schedulers
|
||||
|
||||
def on_train_batch_end(self,outputs, batch, batch_idx):
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
"""
|
||||
This method is called at the end of each training batch, and ovverides
|
||||
the PytorchLightining implementation for logging the checkpoints.
|
||||
@@ -235,7 +236,9 @@ class CompetitivePINN(PINNInterface):
|
||||
:rtype: Any
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += (
|
||||
1
|
||||
)
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
|
||||
def _train_discriminator(self, samples, equation, discriminator_bets):
|
||||
@@ -252,13 +255,14 @@ class CompetitivePINN(PINNInterface):
|
||||
self.optimizer_discriminator.zero_grad()
|
||||
# compute residual, we detach because the weights of the generator
|
||||
# model are fixed
|
||||
residual = self.compute_residual(samples=samples,
|
||||
equation=equation).detach()
|
||||
residual = self.compute_residual(
|
||||
samples=samples, equation=equation
|
||||
).detach()
|
||||
# compute competitive residual, the minus is because we maximise
|
||||
competitive_residual = residual * discriminator_bets
|
||||
loss_val = - self.loss(
|
||||
loss_val = -self.loss(
|
||||
torch.zeros_like(competitive_residual, requires_grad=True),
|
||||
competitive_residual
|
||||
competitive_residual,
|
||||
).as_subclass(torch.Tensor)
|
||||
# backprop
|
||||
self.manual_backward(loss_val)
|
||||
@@ -283,16 +287,13 @@ class CompetitivePINN(PINNInterface):
|
||||
residual = self.compute_residual(samples=samples, equation=equation)
|
||||
# store logging
|
||||
with torch.no_grad():
|
||||
loss_residual = self.loss(
|
||||
torch.zeros_like(residual),
|
||||
residual
|
||||
)
|
||||
loss_residual = self.loss(torch.zeros_like(residual), residual)
|
||||
# compute competitive residual, discriminator_bets are detached becase
|
||||
# we optimize only the generator model
|
||||
competitive_residual = residual * discriminator_bets.detach()
|
||||
loss_val = self.loss(
|
||||
torch.zeros_like(competitive_residual, requires_grad=True),
|
||||
competitive_residual
|
||||
competitive_residual,
|
||||
).as_subclass(torch.Tensor)
|
||||
# backprop
|
||||
self.manual_backward(loss_val)
|
||||
@@ -357,4 +358,4 @@ class CompetitivePINN(PINNInterface):
|
||||
:return: The scheduler for the discriminator.
|
||||
:rtype: torch.optim.lr_scheduler._LRScheduler
|
||||
"""
|
||||
return self._schedulers[1]
|
||||
return self._schedulers[1]
|
||||
|
||||
Reference in New Issue
Block a user