Formatting
* Adding black as dev dependency
* Formatting pina code
* Formatting tests
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
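The commit applies black, so every hunk below is a mechanical reformatting rather than a behavioral change. As a minimal sketch (not part of the commit), the reformatting of the constructor signature in the second hunk can be reproduced with black's Python API; black.Mode() defaults to a line length of 88, and the project's actual configuration is an assumption here:

    import black

    # Pre-commit form of the SelfAdaptivePINN constructor signature.
    # black only parses the source, so the unresolved torch name is fine.
    src = (
        "def __init__(self, problem, model, "
        "weight_function=torch.nn.Sigmoid(), optimizer_model=None, "
        "optimizer_weights=None, scheduler_model=None, "
        "scheduler_weights=None, weighting=None, loss=None): ...\n"
    )

    # Prints the exploded, one-parameter-per-line form with a trailing
    # comma, as seen in the diff below.
    print(black.format_str(src, mode=black.Mode()))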
@@ -1,4 +1,4 @@
-""" Module for Self-Adaptive PINN. """
+"""Module for Self-Adaptive PINN."""
 
 import torch
 from copy import deepcopy
@@ -99,25 +99,27 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
     j.jcp.2022.111722 <https://doi.org/10.1016/j.jcp.2022.111722>`_.
     """
 
-    def __init__(self,
-                 problem,
-                 model,
-                 weight_function=torch.nn.Sigmoid(),
-                 optimizer_model=None,
-                 optimizer_weights=None,
-                 scheduler_model=None,
-                 scheduler_weights=None,
-                 weighting=None,
-                 loss=None):
+    def __init__(
+        self,
+        problem,
+        model,
+        weight_function=torch.nn.Sigmoid(),
+        optimizer_model=None,
+        optimizer_weights=None,
+        scheduler_model=None,
+        scheduler_weights=None,
+        weighting=None,
+        loss=None,
+    ):
         """
         :param AbstractProblem problem: The formulation of the problem.
         :param torch.nn.Module model: The neural network model to use for
             the model.
         :param torch.nn.Module weight_function: The neural network model
             related to the Self-Adaptive PINN mask; default `torch.nn.Sigmoid()`.
         :param torch.optim.Optimizer optimizer_model: The neural network
             optimizer to use for the model network; default `None`.
         :param torch.optim.Optimizer optimizer_weights: The neural network
             optimizer to use for mask model; default `None`.
         :param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
             for the model; default `None`.
@@ -137,12 +139,14 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
             weights_dict[condition_name] = Weights(weight_function)
         weights_dict = torch.nn.ModuleDict(weights_dict)
 
-        super().__init__(models=[model, weights_dict],
-                         problem=problem,
-                         optimizers=[optimizer_model, optimizer_weights],
-                         schedulers=[scheduler_model, scheduler_weights],
-                         weighting=weighting,
-                         loss=loss)
+        super().__init__(
+            models=[model, weights_dict],
+            problem=problem,
+            optimizers=[optimizer_model, optimizer_weights],
+            schedulers=[scheduler_model, scheduler_weights],
+            weighting=weighting,
+            loss=loss,
+        )
 
         # Set automatic optimization to False
         self.automatic_optimization = False
@@ -202,7 +206,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         # create a new one by setting requires_grad to True.
         # In alternative set `retain_graph=True`.
         samples = samples.detach()
-        samples.requires_grad_()# = True
+        samples.requires_grad_()  # = True
 
         # Train the model
         weighted_loss = self._loss_phys(samples, equation)
@@ -244,20 +248,18 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         self.optimizer_weights.hook(self.weights_dict.parameters())
         if isinstance(self.problem, InverseProblem):
             self.optimizer_model.instance.add_param_group(
                 {
                     "params": [
                         self._params[var]
                         for var in self.problem.unknown_variables
                     ]
                 }
             )
         self.scheduler_model.hook(self.optimizer_model)
         self.scheduler_weights.hook(self.optimizer_weights)
         return (
-            [self.optimizer_model.instance,
-             self.optimizer_weights.instance],
-            [self.scheduler_model.instance,
-             self.scheduler_weights.instance]
+            [self.optimizer_model.instance, self.optimizer_weights.instance],
+            [self.scheduler_model.instance, self.scheduler_weights.instance],
         )
 
     def on_train_batch_end(self, outputs, batch, batch_idx):
@@ -275,8 +277,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         """
         # increase by one the counter of optimization to save loggers
         (
-            self.trainer.fit_loop.epoch_loop.manual_optimization
-            .optim_step_progress.total.completed
+            self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
         ) += 1
 
         return super().on_train_batch_end(outputs, batch, batch_idx)
@@ -291,19 +292,22 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         :rtype: Any
         """
         if self.trainer.batch_size is not None:
-            raise NotImplementedError("SelfAdaptivePINN only works with full "
-                                      "batch size, set batch_size=None inside "
-                                      "the Trainer to use the solver.")
+            raise NotImplementedError(
+                "SelfAdaptivePINN only works with full "
+                "batch size, set batch_size=None inside "
+                "the Trainer to use the solver."
+            )
         device = torch.device(
             self.trainer._accelerator_connector._accelerator_flag
         )
 
         # Initialize the self adaptive weights only for training points
-        for condition_name, tensor in (
-            self.trainer.data_module.train_dataset.input_points.items()
-        ):
-            self.weights_dict[condition_name].sa_weights.data = (
-                torch.rand((tensor.shape[0], 1), device=device)
-            )
+        for (
+            condition_name,
+            tensor,
+        ) in self.trainer.data_module.train_dataset.input_points.items():
+            self.weights_dict[condition_name].sa_weights.data = torch.rand(
+                (tensor.shape[0], 1), device=device
+            )
         return super().on_train_start()
 
@@ -318,11 +322,11 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
         # First initialize self-adaptive weights with correct shape,
         # then load the values from the checkpoint.
         for condition_name, _ in self.problem.input_pts.items():
-            shape = checkpoint['state_dict'][
+            shape = checkpoint["state_dict"][
                 f"_pina_models.1.{condition_name}.sa_weights"
             ].shape
-            self.weights_dict[condition_name].sa_weights.data = (
-                torch.rand(shape)
-            )
+            self.weights_dict[condition_name].sa_weights.data = torch.rand(
+                shape
+            )
         return super().on_load_checkpoint(checkpoint)
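With black as a dev dependency, formatting can also be verified without rewriting files. A minimal sketch, assuming a hypothetical file path (the usual command-line equivalent is black --check .):

    import pathlib

    import black

    # Hypothetical location of the file touched by this diff.
    source = pathlib.Path("pina/solvers/pinns/sapinn.py").read_text()

    try:
        # Raises black.NothingChanged when the file is already formatted.
        black.format_file_contents(source, fast=False, mode=black.Mode())
        print("file needs reformatting")
    except black.NothingChanged:
        print("file is already black-formatted")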