Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -1,4 +1,4 @@
|
||||
""" Module for Competitive PINN. """
|
||||
"""Module for Competitive PINN."""
|
||||
|
||||
import torch
|
||||
import copy
|
||||
@@ -55,16 +55,18 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
``extra_feature``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
discriminator=None,
|
||||
optimizer_model=None,
|
||||
optimizer_discriminator=None,
|
||||
scheduler_model=None,
|
||||
scheduler_discriminator=None,
|
||||
weighting=None,
|
||||
loss=None):
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
discriminator=None,
|
||||
optimizer_model=None,
|
||||
optimizer_discriminator=None,
|
||||
scheduler_model=None,
|
||||
scheduler_discriminator=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use
|
||||
@@ -72,13 +74,13 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
:param torch.nn.Module discriminator: The neural network model to use
|
||||
for the discriminator. If ``None``, the discriminator network will
|
||||
have the same architecture as the model network.
|
||||
:param torch.optim.Optimizer optimizer_model: The neural network
|
||||
:param torch.optim.Optimizer optimizer_model: The neural network
|
||||
optimizer to use for the model network; default `None`.
|
||||
:param torch.optim.Optimizer optimizer_discriminator: The neural network
|
||||
optimizer to use for the discriminator network; default `None`.
|
||||
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
|
||||
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
|
||||
for the model; default `None`.
|
||||
:param torch.optim.LRScheduler scheduler_discriminator: Learning rate
|
||||
:param torch.optim.LRScheduler scheduler_discriminator: Learning rate
|
||||
scheduler for the discriminator; default `None`.
|
||||
:param WeightingInterface weighting: The weighting schema to use;
|
||||
default `None`.
|
||||
@@ -88,12 +90,14 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
if discriminator is None:
|
||||
discriminator = copy.deepcopy(model)
|
||||
|
||||
super().__init__(models=[model, discriminator],
|
||||
problem=problem,
|
||||
optimizers=[optimizer_model, optimizer_discriminator],
|
||||
schedulers=[scheduler_model, scheduler_discriminator],
|
||||
weighting=weighting,
|
||||
loss=loss)
|
||||
super().__init__(
|
||||
models=[model, discriminator],
|
||||
problem=problem,
|
||||
optimizers=[optimizer_model, optimizer_discriminator],
|
||||
schedulers=[scheduler_model, scheduler_discriminator],
|
||||
weighting=weighting,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
# Set automatic optimization to False
|
||||
self.automatic_optimization = False
|
||||
@@ -158,7 +162,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
|
||||
def loss_data(self, input_pts, output_pts):
|
||||
"""
|
||||
The data loss for the CompetitivePINN solver. It computes the loss
|
||||
The data loss for the CompetitivePINN solver. It computes the loss
|
||||
between the network output against the true solution.
|
||||
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
@@ -167,7 +171,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
:return: The computed data loss.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
loss_val = (super().loss_data(input_pts, output_pts))
|
||||
loss_val = super().loss_data(input_pts, output_pts)
|
||||
# prepare for optimizer step called in training step
|
||||
loss_val.backward()
|
||||
return loss_val
|
||||
@@ -195,10 +199,14 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
self.scheduler_model.hook(self.optimizer_model)
|
||||
self.scheduler_discriminator.hook(self.optimizer_discriminator)
|
||||
return (
|
||||
[self.optimizer_model.instance,
|
||||
self.optimizer_discriminator.instance],
|
||||
[self.scheduler_model.instance,
|
||||
self.scheduler_discriminator.instance]
|
||||
[
|
||||
self.optimizer_model.instance,
|
||||
self.optimizer_discriminator.instance,
|
||||
],
|
||||
[
|
||||
self.scheduler_model.instance,
|
||||
self.scheduler_discriminator.instance,
|
||||
],
|
||||
)
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
@@ -216,8 +224,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
(
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization
|
||||
.optim_step_progress.total.completed
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
|
||||
) += 1
|
||||
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
|
||||
Reference in New Issue
Block a user