GAROM solver loss update

Changing from `LpLoss` to `PowerLoss`
Author: Dario Coscia
Date: 2023-10-06 15:54:10 +02:00
Committed by: Nicola Demo
Parent: 80b4b43460
Commit: 2e2fe93458

@@ -1,4 +1,5 @@
-""" Module for PINN """
+""" Module for GAROM """
 import torch
 try:
     from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
@@ -8,7 +9,7 @@ except ImportError:
     from torch.optim.lr_scheduler import ConstantLR
 from .solver import SolverInterface
 from ..utils import check_consistency
-from ..loss import LossInterface, LpLoss
+from ..loss import LossInterface, PowerLoss
 from torch.nn.modules.loss import _Loss
@@ -58,7 +59,7 @@ class GAROM(SolverInterface):
         extra features for each.
     :param torch.nn.Module loss: The loss function used as minimizer,
         default ``None``. If ``loss`` is ``None`` the defualt
-        ``LpLoss(p=1)`` is used, as in the original paper.
+        ``PowerLoss(p=1)`` is used, as in the original paper.
     :param torch.optim.Optimizer optimizer_generator: The neural
         network optimizer to use for the generator network
         , default is `torch.optim.Adam`.
@@ -102,7 +103,7 @@ class GAROM(SolverInterface):
         # set loss
         if loss is None:
-            loss = LpLoss(p=1)
+            loss = PowerLoss(p=1)

         # check consistency
         check_consistency(scheduler_generator, LRScheduler, subclass=True)
@@ -264,4 +265,4 @@ class GAROM(SolverInterface):
     @property
     def scheduler_discriminator(self):
         return self._schedulers[1]
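
For context on what the swap changes: both classes are imported from `pina.loss` (the `from ..loss import ...` line above), and they differ in how the residual is reduced. A power-type loss averages the elementwise error raised to the power `p`; an Lp-type loss aggregates the powered errors and then takes the `1/p` root. Below is a minimal sketch of that distinction using hypothetical stand-in functions (`power_loss`, `lp_loss`), not PINA's actual implementations:

```python
import torch

def power_loss(x, y, p=1):
    # Power-type reduction: mean of the elementwise error |x - y|**p.
    # For p=1 this is the mean absolute error, matching the spirit of
    # the new PowerLoss(p=1) default.
    return torch.mean(torch.abs(x - y) ** p)

def lp_loss(x, y, p=1):
    # Lp-norm reduction: (sum of |x - y|**p) ** (1/p).
    # For p=1 the root is a no-op, so this is the un-averaged L1 norm.
    return torch.sum(torch.abs(x - y) ** p) ** (1.0 / p)

x = torch.tensor([0.0, 1.0, 2.0])
y = torch.tensor([0.5, 1.5, 1.0])
print(power_loss(x, y))  # tensor(0.6667)
print(lp_loss(x, y))     # tensor(2.)
```

With `p=1` the two reductions differ only in mean-versus-sum normalization; the substantive divergence appears for `p > 1`, where the Lp loss applies the `1/p` root to the aggregate and the power loss does not.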