fix rendering part 2
This commit is contained in:
@@ -59,11 +59,11 @@ class RBAPINN(PINN):
|
||||
.. seealso::
|
||||
**Original reference**: Sokratis J. Anagnostopoulos, Juan D. Toscano,
|
||||
Nikolaos Stergiopulos, and George E. Karniadakis.
|
||||
"Residual-based attention and connection to information
|
||||
bottleneck theory in PINNs".
|
||||
*Residual-based attention and connection to information
|
||||
bottleneck theory in PINNs.*
|
||||
Computer Methods in Applied Mechanics and Engineering 421 (2024): 116805
|
||||
DOI: `10.1016/
|
||||
j.cma.2024.116805 <https://doi.org/10.1016/j.cma.2024.116805>`_.
|
||||
DOI: `10.1016/j.cma.2024.116805
|
||||
<https://doi.org/10.1016/j.cma.2024.116805>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -82,15 +82,16 @@ class RBAPINN(PINN):
|
||||
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module model: The neural network model to be used.
|
||||
:param torch.optim.Optimizer optimizer: The optimizer to be used.
|
||||
If `None`, the Adam optimizer is used. Default is ``None``.
|
||||
:param torch.optim.LRScheduler scheduler: Learning rate scheduler.
|
||||
If `None`, the constant learning rate scheduler is used.
|
||||
param Optimizer optimizer: The optimizer to be used.
|
||||
If `None`, the :class:`torch.optim.Adam` optimizer is used.
|
||||
Default is ``None``.
|
||||
:param Scheduler scheduler: Learning rate scheduler.
|
||||
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
|
||||
scheduler is used. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the Mean Squared Error (MSE) loss is used.
|
||||
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||
Default is `None`.
|
||||
:param float | int eta: The learning rate for the weights of the
|
||||
residuals. Default is ``0.001``.
|
||||
@@ -147,7 +148,7 @@ class RBAPINN(PINN):
|
||||
:param LabelTensor loss_value: the tensor of pointwise losses.
|
||||
:raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
|
||||
:return: The computed scalar loss.
|
||||
:rtype LabelTensor
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
if self.loss.reduction == "mean":
|
||||
ret = torch.mean(loss_value)
|
||||
|
||||
Reference in New Issue
Block a user