🎨 Format Python code with psf/black (#297)
Co-authored-by: dario-coscia <dario-coscia@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e0429bb445
commit
9463ae4b15
@@ -14,6 +14,7 @@ from pina.problem import InverseProblem
|
||||
|
||||
from torch.optim.lr_scheduler import ConstantLR
|
||||
|
||||
|
||||
class Weights(torch.nn.Module):
|
||||
"""
|
||||
This class aims to implements the mask model for
|
||||
@@ -27,11 +28,9 @@ class Weights(torch.nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
check_consistency(func, torch.nn.Module)
|
||||
self.sa_weights = torch.nn.Parameter(
|
||||
torch.Tensor()
|
||||
)
|
||||
self.sa_weights = torch.nn.Parameter(torch.Tensor())
|
||||
self.func = func
|
||||
|
||||
|
||||
def forward(self):
|
||||
"""
|
||||
Forward pass implementation for the mask module.
|
||||
@@ -43,6 +42,7 @@ class Weights(torch.nn.Module):
|
||||
"""
|
||||
return self.func(self.sa_weights)
|
||||
|
||||
|
||||
class SAPINN(PINNInterface):
|
||||
r"""
|
||||
Self Adaptive Physics Informed Neural Network (SAPINN) solver class.
|
||||
@@ -106,22 +106,22 @@ class SAPINN(PINNInterface):
|
||||
DOI: `10.1016/
|
||||
j.jcp.2022.111722 <https://doi.org/10.1016/j.jcp.2022.111722>`_.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
weights_function=torch.nn.Sigmoid(),
|
||||
extra_features=None,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer_model=torch.optim.Adam,
|
||||
optimizer_model_kwargs={"lr" : 0.001},
|
||||
optimizer_weights=torch.optim.Adam,
|
||||
optimizer_weights_kwargs={"lr" : 0.001},
|
||||
scheduler_model=ConstantLR,
|
||||
scheduler_model_kwargs={"factor" : 1, "total_iters" : 0},
|
||||
scheduler_weights=ConstantLR,
|
||||
scheduler_weights_kwargs={"factor" : 1, "total_iters" : 0}
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
weights_function=torch.nn.Sigmoid(),
|
||||
extra_features=None,
|
||||
loss=torch.nn.MSELoss(),
|
||||
optimizer_model=torch.optim.Adam,
|
||||
optimizer_model_kwargs={"lr": 0.001},
|
||||
optimizer_weights=torch.optim.Adam,
|
||||
optimizer_weights_kwargs={"lr": 0.001},
|
||||
scheduler_model=ConstantLR,
|
||||
scheduler_model_kwargs={"factor": 1, "total_iters": 0},
|
||||
scheduler_weights=ConstantLR,
|
||||
scheduler_weights_kwargs={"factor": 1, "total_iters": 0},
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
@@ -167,19 +167,18 @@ class SAPINN(PINNInterface):
|
||||
weights_dict[condition_name] = Weights(weights_function)
|
||||
weights_dict = torch.nn.ModuleDict(weights_dict)
|
||||
|
||||
|
||||
super().__init__(
|
||||
models=[model, weights_dict],
|
||||
problem=problem,
|
||||
optimizers=[optimizer_model, optimizer_weights],
|
||||
optimizers_kwargs=[
|
||||
optimizer_model_kwargs,
|
||||
optimizer_weights_kwargs
|
||||
optimizer_weights_kwargs,
|
||||
],
|
||||
extra_features=extra_features,
|
||||
loss=loss
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
|
||||
# set automatic optimization
|
||||
self.automatic_optimization = False
|
||||
|
||||
@@ -191,12 +190,8 @@ class SAPINN(PINNInterface):
|
||||
|
||||
# assign schedulers
|
||||
self._schedulers = [
|
||||
scheduler_model(
|
||||
self.optimizers[0], **scheduler_model_kwargs
|
||||
),
|
||||
scheduler_weights(
|
||||
self.optimizers[1], **scheduler_weights_kwargs
|
||||
),
|
||||
scheduler_model(self.optimizers[0], **scheduler_model_kwargs),
|
||||
scheduler_weights(self.optimizers[1], **scheduler_weights_kwargs),
|
||||
]
|
||||
|
||||
self._model = self.models[0]
|
||||
@@ -204,7 +199,7 @@ class SAPINN(PINNInterface):
|
||||
|
||||
self._vectorial_loss = deepcopy(loss)
|
||||
self._vectorial_loss.reduction = "none"
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass implementation for the PINN
|
||||
@@ -219,7 +214,7 @@ class SAPINN(PINNInterface):
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self.neural_net(x)
|
||||
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the SAPINN solver based on given
|
||||
@@ -235,7 +230,7 @@ class SAPINN(PINNInterface):
|
||||
# train weights
|
||||
self.optimizer_weights.zero_grad()
|
||||
weighted_loss, _ = self._loss_phys(samples, equation)
|
||||
loss_value = - weighted_loss.as_subclass(torch.Tensor)
|
||||
loss_value = -weighted_loss.as_subclass(torch.Tensor)
|
||||
self.manual_backward(loss_value)
|
||||
self.optimizer_weights.step()
|
||||
|
||||
@@ -271,7 +266,7 @@ class SAPINN(PINNInterface):
|
||||
# train weights
|
||||
self.optimizer_weights.zero_grad()
|
||||
weighted_loss, _ = self._loss_data(input_tensor, output_tensor)
|
||||
loss_value = - weighted_loss.as_subclass(torch.Tensor)
|
||||
loss_value = -weighted_loss.as_subclass(torch.Tensor)
|
||||
self.manual_backward(loss_value)
|
||||
self.optimizer_weights.step()
|
||||
|
||||
@@ -291,7 +286,7 @@ class SAPINN(PINNInterface):
|
||||
# store loss without weights
|
||||
self.store_log(loss_value=float(loss))
|
||||
return loss_value
|
||||
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Optimizer configuration for the SAPINN
|
||||
@@ -312,8 +307,8 @@ class SAPINN(PINNInterface):
|
||||
}
|
||||
)
|
||||
return self.optimizers, self._schedulers
|
||||
|
||||
def on_train_batch_end(self,outputs, batch, batch_idx):
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
"""
|
||||
This method is called at the end of each training batch, and ovverides
|
||||
the PytorchLightining implementation for logging the checkpoints.
|
||||
@@ -327,9 +322,11 @@ class SAPINN(PINNInterface):
|
||||
:rtype: Any
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
|
||||
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += (
|
||||
1
|
||||
)
|
||||
return super().on_train_batch_end(outputs, batch, batch_idx)
|
||||
|
||||
|
||||
def on_train_start(self):
|
||||
"""
|
||||
This method is called at the start of the training for setting
|
||||
@@ -343,12 +340,11 @@ class SAPINN(PINNInterface):
|
||||
self.trainer._accelerator_connector._accelerator_flag
|
||||
)
|
||||
for condition_name, tensor in self.problem.input_pts.items():
|
||||
self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand(
|
||||
(tensor.shape[0], 1),
|
||||
device = device
|
||||
self.weights_dict.torchmodel[condition_name].sa_weights.data = (
|
||||
torch.rand((tensor.shape[0], 1), device=device)
|
||||
)
|
||||
return super().on_train_start()
|
||||
|
||||
|
||||
def on_load_checkpoint(self, checkpoint):
|
||||
"""
|
||||
Overriding the Pytorch Lightning ``on_load_checkpoint`` to handle
|
||||
@@ -358,8 +354,8 @@ class SAPINN(PINNInterface):
|
||||
:param dict checkpoint: Pytorch Lightning checkpoint dict.
|
||||
"""
|
||||
for condition_name, tensor in self.problem.input_pts.items():
|
||||
self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand(
|
||||
(tensor.shape[0], 1)
|
||||
self.weights_dict.torchmodel[condition_name].sa_weights.data = (
|
||||
torch.rand((tensor.shape[0], 1))
|
||||
)
|
||||
return super().on_load_checkpoint(checkpoint)
|
||||
|
||||
@@ -370,13 +366,13 @@ class SAPINN(PINNInterface):
|
||||
:param LabelTensor samples: Input samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: the governing equation representing
|
||||
the physics.
|
||||
|
||||
|
||||
:return: tuple with weighted and not weighted scalar loss
|
||||
:rtype: List[LabelTensor, LabelTensor]
|
||||
"""
|
||||
residual = self.compute_residual(samples, equation)
|
||||
return self._compute_loss(residual)
|
||||
|
||||
|
||||
def _loss_data(self, input_tensor, output_tensor):
|
||||
"""
|
||||
Elaboration of the loss related to data for the SAPINN solver.
|
||||
@@ -384,7 +380,7 @@ class SAPINN(PINNInterface):
|
||||
:param LabelTensor input_tensor: The input to the neural networks.
|
||||
:param LabelTensor output_tensor: The true solution to compare the
|
||||
network solution.
|
||||
|
||||
|
||||
:return: tuple with weighted and not weighted scalar loss
|
||||
:rtype: List[LabelTensor, LabelTensor]
|
||||
"""
|
||||
@@ -396,19 +392,21 @@ class SAPINN(PINNInterface):
|
||||
Elaboration of the pointwise loss through the mask model and the
|
||||
self adaptive weights
|
||||
|
||||
:param LabelTensor residual: the matrix of residuals that have to
|
||||
:param LabelTensor residual: the matrix of residuals that have to
|
||||
be weighted
|
||||
|
||||
:return: tuple with weighted and not weighted loss
|
||||
:rtype List[LabelTensor, LabelTensor]
|
||||
"""
|
||||
weights = self.weights_dict.torchmodel[
|
||||
self.current_condition_name].forward()
|
||||
loss_value = self._vectorial_loss(torch.zeros_like(
|
||||
residual, requires_grad=True), residual)
|
||||
self.current_condition_name
|
||||
].forward()
|
||||
loss_value = self._vectorial_loss(
|
||||
torch.zeros_like(residual, requires_grad=True), residual
|
||||
)
|
||||
return (
|
||||
self._vect_to_scalar(weights * loss_value),
|
||||
self._vect_to_scalar(loss_value)
|
||||
self._vect_to_scalar(loss_value),
|
||||
)
|
||||
|
||||
def _vect_to_scalar(self, loss_value):
|
||||
@@ -426,10 +424,11 @@ class SAPINN(PINNInterface):
|
||||
elif self.loss.reduction == "sum":
|
||||
ret = torch.sum(loss_value)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid reduction, got {self.loss.reduction} "
|
||||
"but expected mean or sum.")
|
||||
raise RuntimeError(
|
||||
f"Invalid reduction, got {self.loss.reduction} "
|
||||
"but expected mean or sum."
|
||||
)
|
||||
return ret
|
||||
|
||||
|
||||
@property
|
||||
def neural_net(self):
|
||||
@@ -440,7 +439,7 @@ class SAPINN(PINNInterface):
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self.models[0]
|
||||
|
||||
|
||||
@property
|
||||
def weights_dict(self):
|
||||
"""
|
||||
@@ -462,7 +461,7 @@ class SAPINN(PINNInterface):
|
||||
:rtype: torch.optim.lr_scheduler._LRScheduler
|
||||
"""
|
||||
return self._scheduler[0]
|
||||
|
||||
|
||||
@property
|
||||
def scheduler_weights(self):
|
||||
"""
|
||||
@@ -482,7 +481,7 @@ class SAPINN(PINNInterface):
|
||||
:rtype: torch.optim.Optimizer
|
||||
"""
|
||||
return self.optimizers[0]
|
||||
|
||||
|
||||
@property
|
||||
def optimizer_weights(self):
|
||||
"""
|
||||
@@ -491,4 +490,4 @@ class SAPINN(PINNInterface):
|
||||
:return: The optimizer for the mask model.
|
||||
:rtype: torch.optim.Optimizer
|
||||
"""
|
||||
return self.optimizers[1]
|
||||
return self.optimizers[1]
|
||||
|
||||
Reference in New Issue
Block a user