fix rendering part 2
@@ -45,16 +45,22 @@ class GAROM(MultiSolverInterface):
 :param torch.nn.Module generator: The generator model.
 :param torch.nn.Module discriminator: The discriminator model.
 :param torch.nn.Module loss: The loss function to be minimized.
-    If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``.
+    If ``None``, :class:`~pina.loss.power_loss.PowerLoss` with ``p=1``
+    is used. Default is ``None``.
 :param Optimizer optimizer_generator: The optimizer for the generator.
-    If `None`, the Adam optimizer is used. Default is ``None``.
-:param Optimizer optimizer_discriminator: The optimizer for the
-    discriminator. If `None`, the Adam optimizer is used.
+    If `None`, the :class:`torch.optim.Adam` optimizer is used.
     Default is ``None``.
+:param Optimizer optimizer_discriminator: The optimizer for the
+    discriminator. If `None`, the :class:`torch.optim.Adam` optimizer is
+    used. Default is ``None``.
 :param Scheduler scheduler_generator: The learning rate scheduler for
     the generator.
     If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
     scheduler is used. Default is ``None``.
 :param Scheduler scheduler_discriminator: The learning rate scheduler
     for the discriminator.
     If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
     scheduler is used. Default is ``None``.
 :param float gamma: Ratio of expected loss for generator and
     discriminator. Default is ``0.3``.
 :param float lambda_k: Learning rate for control theory optimization.
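Taken together, the defaults documented in this hunk suggest a construction like the minimal sketch below. The import path, the `problem` argument and the two network architectures are assumptions for illustration; only the keyword names and defaults listed above come from the docstring:

    import torch
    from pina.solver import GAROM  # import path assumed

    # Illustrative networks only; any pair of torch.nn.Module instances works.
    generator = torch.nn.Sequential(
        torch.nn.Linear(2, 64), torch.nn.ReLU(), torch.nn.Linear(64, 100)
    )
    discriminator = torch.nn.Sequential(
        torch.nn.Linear(100, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
    )

    solver = GAROM(
        problem=problem,               # an already-defined PINA problem (keyword assumed)
        generator=generator,
        discriminator=discriminator,
        loss=None,                     # None -> PowerLoss with p=1
        optimizer_generator=None,      # None -> torch.optim.Adam
        optimizer_discriminator=None,  # None -> torch.optim.Adam
        scheduler_generator=None,      # None -> torch.optim.lr_scheduler.ConstantLR
        scheduler_discriminator=None,  # None -> torch.optim.lr_scheduler.ConstantLR
        gamma=0.3,                     # documented default
        lambda_k=0.001,                # value assumed; learning rate for control theory optimization
    )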
@@ -109,7 +115,7 @@ class GAROM(MultiSolverInterface):
     of the solution. Default is ``False``.
 :return: The expected value of the generator distribution. If
     ``variance=True``, the method returns also the variance.
-:rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
+:rtype: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
 """
 
 # sampling
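The corrected ``:rtype:`` states that the sampling routine returns either a tensor or a (mean, variance) pair. Reusing the `solver` object from the sketch above, and assuming the routine is exposed as `forward` (the method name is not shown in this hunk), usage would look like:

    params = torch.rand(16, 2)                         # illustrative input parameters
    mean = solver.forward(params)                      # torch.Tensor
    mean, var = solver.forward(params, variance=True)  # tuple[torch.Tensor, torch.Tensor]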
@@ -143,7 +149,7 @@ class GAROM(MultiSolverInterface):
 :param torch.Tensor parameters: The input tensor.
 :param torch.Tensor snapshots: The target tensor.
 :return: The residual loss and the generator loss.
-:rtype: tuple(torch.Tensor, torch.Tensor)
+:rtype: tuple[torch.Tensor, torch.Tensor]
 """
 optimizer = self.optimizer_generator
 optimizer.zero_grad()
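Only the first two lines of the generator step are visible here: the generator optimizer is fetched and its gradients are zeroed before the (residual loss, generator loss) pair is returned. A generic, self-contained sketch of that zero_grad/backward/step pattern follows; the loss computation and call signatures are assumptions, not the solver's actual implementation:

    import torch

    def generator_step(generator, optimizer, loss_fn, parameters, snapshots):
        """Hypothetical helper mirroring the documented interface:
        takes (parameters, snapshots) and returns (residual loss, generator loss)."""
        optimizer.zero_grad()
        generated = generator(parameters)        # assumed call signature
        residual_loss = loss_fn(snapshots, generated)
        generator_loss = residual_loss           # the adversarial term is omitted in this sketch
        generator_loss.backward()
        optimizer.step()
        return residual_loss, generator_loss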
@@ -170,7 +176,8 @@ class GAROM(MultiSolverInterface):
 
 :param torch.Tensor outputs: The ``model``'s output for the current
     batch.
-:param dict batch: The current batch of data.
+:param list[tuple[str, dict]] batch: A batch of data. Each element is a
+    tuple containing a condition name and a dictionary of points.
 :param int batch_idx: The index of the current batch.
 """
 # increase by one the counter of optimization to save loggers
@@ -187,7 +194,7 @@ class GAROM(MultiSolverInterface):
 :param torch.Tensor parameters: The input tensor.
 :param torch.Tensor snapshots: The target tensor.
 :return: The residual loss and the generator loss.
-:rtype: tuple(torch.Tensor, torch.Tensor)
+:rtype: tuple[torch.Tensor, torch.Tensor]
 """
 optimizer = self.optimizer_discriminator
 optimizer.zero_grad()
@@ -234,9 +241,12 @@ class GAROM(MultiSolverInterface):
 """
 The optimization cycle for the GAROM solver.
 
-:param tuple batch: The batch element in the dataloader.
-:return: The loss of the optimization cycle.
-:rtype: LabelTensor
+:param list[tuple[str, dict]] batch: A batch of data. Each element is a
+    tuple containing a condition name and a dictionary of points.
+:return: The losses computed for all conditions in the batch, casted
+    to a subclass of :class:`torch.Tensor`. It should return a dict
+    containing the condition name and the associated scalar loss.
+:rtype: dict
 """
 condition_loss = {}
 for condition_name, points in batch:
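The loop visible at the end of the hunk iterates over (condition name, points) pairs and fills a per-condition loss dictionary, which is what the new ``:rtype: dict`` documents. A self-contained sketch of that data flow, with placeholder names and a placeholder loss in place of GAROM's actual generator/discriminator losses:

    import torch

    # Condition names and dictionary keys are assumptions for illustration.
    batch = [
        ("data", {"input_points": torch.rand(8, 2), "output_points": torch.rand(8, 100)}),
    ]

    condition_loss = {}
    for condition_name, points in batch:
        # each entry must be a scalar tensor (a torch.Tensor subclass is also fine)
        condition_loss[condition_name] = points["output_points"].pow(2).mean()

    print(condition_loss)  # {'data': tensor(...)}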
@@ -265,7 +275,8 @@ class GAROM(MultiSolverInterface):
 """
 The validation step for the PINN solver.
 
-:param dict batch: The batch of data to use in the validation step.
+:param list[tuple[str, dict]] batch: A batch of data. Each element is a
+    tuple containing a condition name and a dictionary of points.
 :return: The loss of the validation step.
 :rtype: torch.Tensor
 """
@@ -287,7 +298,8 @@ class GAROM(MultiSolverInterface):
 """
 The test step for the PINN solver.
 
-:param dict batch: The batch of data to use in the test step.
+:param list[tuple[str, dict]] batch: A batch of data. Each element is a
+    tuple containing a condition name and a dictionary of points.
 :return: The loss of the test step.
 :rtype: torch.Tensor
 """