Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -1,4 +1,4 @@
|
||||
""" Module for Residual-Based Attention PINN. """
|
||||
"""Module for Residual-Based Attention PINN."""
|
||||
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
@@ -66,15 +66,17 @@ class RBAPINN(PINN):
|
||||
j.cma.2024.116805 <https://doi.org/10.1016/j.cma.2024.116805>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
eta=0.001,
|
||||
gamma=0.999):
|
||||
def __init__(
|
||||
self,
|
||||
problem,
|
||||
model,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
weighting=None,
|
||||
loss=None,
|
||||
eta=0.001,
|
||||
gamma=0.999,
|
||||
):
|
||||
"""
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
@@ -86,17 +88,19 @@ class RBAPINN(PINN):
|
||||
default `None`.
|
||||
:param torch.nn.Module loss: The loss function to be minimized;
|
||||
default `None`.
|
||||
:param float | int eta: The learning rate for the weights of the
|
||||
:param float | int eta: The learning rate for the weights of the
|
||||
residual; default 0.001.
|
||||
:param float gamma: The decay parameter in the update of the weights
|
||||
of the residual. Must be between 0 and 1; default 0.999.
|
||||
"""
|
||||
super().__init__(model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss)
|
||||
super().__init__(
|
||||
model=model,
|
||||
problem=problem,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
weighting=weighting,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
# check consistency
|
||||
check_consistency(eta, (float, int))
|
||||
@@ -119,9 +123,11 @@ class RBAPINN(PINN):
|
||||
# for now RBAPINN is implemented only for batch_size = None
|
||||
def on_train_start(self):
|
||||
if self.trainer.batch_size is not None:
|
||||
raise NotImplementedError("RBAPINN only works with full batch "
|
||||
"size, set batch_size=None inside the "
|
||||
"Trainer to use the solver.")
|
||||
raise NotImplementedError(
|
||||
"RBAPINN only works with full batch "
|
||||
"size, set batch_size=None inside the "
|
||||
"Trainer to use the solver."
|
||||
)
|
||||
return super().on_train_start()
|
||||
|
||||
def _vect_to_scalar(self, loss_value):
|
||||
@@ -160,10 +166,11 @@ class RBAPINN(PINN):
|
||||
cond = self.current_condition_name
|
||||
|
||||
r_norm = (
|
||||
self.eta * torch.abs(residual)
|
||||
self.eta
|
||||
* torch.abs(residual)
|
||||
/ (torch.max(torch.abs(residual)) + 1e-12)
|
||||
)
|
||||
self.weights[cond] = (self.gamma*self.weights[cond] + r_norm).detach()
|
||||
self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
|
||||
|
||||
loss_value = self._vectorial_loss(
|
||||
torch.zeros_like(residual, requires_grad=True), residual
|
||||
|
||||
Reference in New Issue
Block a user