fix rendering part 2

This commit is contained in:
giovanni
2025-03-14 00:10:18 +01:00
committed by Nicola Demo
parent e0ad4dc8a0
commit d2e3f458ab
17 changed files with 217 additions and 147 deletions

View File

@@ -46,8 +46,8 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
.. seealso::
**Original reference**: Zeng, Qi, et al.
"Competitive physics informed networks." International Conference on
Learning Representations, ICLR 2022
*Competitive physics informed networks.*
International Conference on Learning Representations, ICLR 2022
`OpenReview Preprint <https://openreview.net/forum?id=z9SIj-IM7tn>`_.
"""
@@ -72,21 +72,23 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
If `None`, the discriminator is a deepcopy of the ``model``.
Default is ``None``.
:param torch.optim.Optimizer optimizer_model: The optimizer of the
``model``. If `None`, the Adam optimizer is used.
Default is ``None``.
``model``. If `None`, the :class:`torch.optim.Adam` optimizer is
used. Default is ``None``.
:param torch.optim.Optimizer optimizer_discriminator: The optimizer of
the ``discriminator``. If `None`, the Adam optimizer is used.
Default is ``None``.
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
for the ``model``. If `None`, the constant learning rate scheduler
is used. Default is ``None``.
:param torch.optim.LRScheduler scheduler_discriminator: Learning rate
scheduler for the ``discriminator``. If `None`, the constant
learning rate scheduler is used. Default is ``None``.
the ``discriminator``. If `None`, the :class:`torch.optim.Adam`
optimizer is used. Default is ``None``.
:param Scheduler scheduler_model: Learning rate scheduler for the
``model``.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param Scheduler scheduler_discriminator: Learning rate scheduler for
the ``discriminator``.
If `None`, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the Mean Squared Error (MSE) loss is used.
If `None`, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
"""
if discriminator is None:
@@ -118,7 +120,8 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
"""
Solver training step, overridden to perform manual optimization.
:param dict batch: The batch element in the dataloader.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The aggregated loss.
:rtype: LabelTensor
"""
@@ -163,7 +166,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
Optimizer configuration.
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
:rtype: tuple[list[Optimizer], list[Scheduler]]
"""
# If the problem is an InverseProblem, add the unknown parameters
# to the parameters to be optimized
@@ -198,7 +201,8 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
:param torch.Tensor outputs: The ``model``'s output for the current
batch.
:param dict batch: The current batch of data.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
"""
# increase by one the counter of optimization to save loggers
@@ -234,7 +238,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
The optimizer associated to the model.
:return: The optimizer for the model.
:rtype: torch.optim.Optimizer
:rtype: Optimizer
"""
return self.optimizers[0]
@@ -244,7 +248,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
The optimizer associated to the discriminator.
:return: The optimizer for the discriminator.
:rtype: torch.optim.Optimizer
:rtype: Optimizer
"""
return self.optimizers[1]
@@ -254,7 +258,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
The scheduler associated to the model.
:return: The scheduler for the model.
:rtype: torch.optim.lr_scheduler._LRScheduler
:rtype: Scheduler
"""
return self.schedulers[0]
@@ -264,6 +268,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
The scheduler associated to the discriminator.
:return: The scheduler for the discriminator.
:rtype: torch.optim.lr_scheduler._LRScheduler
:rtype: Scheduler
"""
return self.schedulers[1]