fix rendering part 2

This commit is contained in:
giovanni
2025-03-14 00:10:18 +01:00
committed by Nicola Demo
parent e0ad4dc8a0
commit d2e3f458ab
17 changed files with 217 additions and 147 deletions

View File

@@ -45,16 +45,22 @@ class GAROM(MultiSolverInterface):
:param torch.nn.Module generator: The generator model.
:param torch.nn.Module discriminator: The discriminator model.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``.
If ``None``, :class:`~pina.loss.power_loss.PowerLoss` with ``p=1``
is used. Default is ``None``.
:param Optimizer optimizer_generator: The optimizer for the generator.
If `None`, the Adam optimizer is used. Default is ``None``.
:param Optimizer optimizer_discriminator: The optimizer for the
discriminator. If `None`, the Adam optimizer is used.
If `None`, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Optimizer optimizer_discriminator: The optimizer for the
discriminator. If `None`, the :class:`torch.optim.Adam` optimizer is
used. Default is ``None``.
:param Scheduler scheduler_generator: The learning rate scheduler for
the generator.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param Scheduler scheduler_discriminator: The learning rate scheduler
for the discriminator.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param float gamma: Ratio of expected loss for generator and
discriminator. Default is ``0.3``.
:param float lambda_k: Learning rate for control theory optimization.
@@ -109,7 +115,7 @@ class GAROM(MultiSolverInterface):
of the solution. Default is ``False``.
:return: The expected value of the generator distribution. If
``variance=True``, the method returns also the variance.
:rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
:rtype: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
"""
# sampling
@@ -143,7 +149,7 @@ class GAROM(MultiSolverInterface):
:param torch.Tensor parameters: The input tensor.
:param torch.Tensor snapshots: The target tensor.
:return: The residual loss and the generator loss.
:rtype: tuple(torch.Tensor, torch.Tensor)
:rtype: tuple[torch.Tensor, torch.Tensor]
"""
optimizer = self.optimizer_generator
optimizer.zero_grad()
@@ -170,7 +176,8 @@ class GAROM(MultiSolverInterface):
:param torch.Tensor outputs: The ``model``'s output for the current
batch.
:param dict batch: The current batch of data.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
"""
# increase by one the counter of optimization to save loggers
@@ -187,7 +194,7 @@ class GAROM(MultiSolverInterface):
:param torch.Tensor parameters: The input tensor.
:param torch.Tensor snapshots: The target tensor.
:return: The residual loss and the generator loss.
:rtype: tuple(torch.Tensor, torch.Tensor)
:rtype: tuple[torch.Tensor, torch.Tensor]
"""
optimizer = self.optimizer_discriminator
optimizer.zero_grad()
@@ -234,9 +241,12 @@ class GAROM(MultiSolverInterface):
"""
The optimization cycle for the GAROM solver.
:param tuple batch: The batch element in the dataloader.
:return: The loss of the optimization cycle.
:rtype: LabelTensor
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The losses computed for all conditions in the batch, casted
to a subclass of :class:`torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict
"""
condition_loss = {}
for condition_name, points in batch:
@@ -265,7 +275,8 @@ class GAROM(MultiSolverInterface):
"""
The validation step for the PINN solver.
:param dict batch: The batch of data to use in the validation step.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
@@ -287,7 +298,8 @@ class GAROM(MultiSolverInterface):
"""
The test step for the PINN solver.
:param dict batch: The batch of data to use in the test step.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The loss of the test step.
:rtype: torch.Tensor
"""