GAROM solver loss update
Changing from `LpLoss` to `PowerLoss`
This commit is contained in:
committed by
Nicola Demo
parent
80b4b43460
commit
2e2fe93458
@@ -1,4 +1,5 @@
|
|||||||
""" Module for PINN """
|
""" Module for GAROM """
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
try:
|
try:
|
||||||
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
|
||||||
@@ -8,7 +9,7 @@ except ImportError:
|
|||||||
from torch.optim.lr_scheduler import ConstantLR
|
from torch.optim.lr_scheduler import ConstantLR
|
||||||
from .solver import SolverInterface
|
from .solver import SolverInterface
|
||||||
from ..utils import check_consistency
|
from ..utils import check_consistency
|
||||||
from ..loss import LossInterface, LpLoss
|
from ..loss import LossInterface, PowerLoss
|
||||||
from torch.nn.modules.loss import _Loss
|
from torch.nn.modules.loss import _Loss
|
||||||
|
|
||||||
|
|
||||||
@@ -58,7 +59,7 @@ class GAROM(SolverInterface):
|
|||||||
extra features for each.
|
extra features for each.
|
||||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||||
default ``None``. If ``loss`` is ``None`` the defualt
|
default ``None``. If ``loss`` is ``None`` the defualt
|
||||||
``LpLoss(p=1)`` is used, as in the original paper.
|
``PowerLoss(p=1)`` is used, as in the original paper.
|
||||||
:param torch.optim.Optimizer optimizer_generator: The neural
|
:param torch.optim.Optimizer optimizer_generator: The neural
|
||||||
network optimizer to use for the generator network
|
network optimizer to use for the generator network
|
||||||
, default is `torch.optim.Adam`.
|
, default is `torch.optim.Adam`.
|
||||||
@@ -102,7 +103,7 @@ class GAROM(SolverInterface):
|
|||||||
|
|
||||||
# set loss
|
# set loss
|
||||||
if loss is None:
|
if loss is None:
|
||||||
loss = LpLoss(p=1)
|
loss = PowerLoss(p=1)
|
||||||
|
|
||||||
# check consistency
|
# check consistency
|
||||||
check_consistency(scheduler_generator, LRScheduler, subclass=True)
|
check_consistency(scheduler_generator, LRScheduler, subclass=True)
|
||||||
@@ -264,4 +265,4 @@ class GAROM(SolverInterface):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def scheduler_discriminator(self):
|
def scheduler_discriminator(self):
|
||||||
return self._schedulers[1]
|
return self._schedulers[1]
|
||||||
|
|||||||
Reference in New Issue
Block a user