GAROM solver loss update
Change the default loss from `LpLoss` to `PowerLoss`.
Committed by Nicola Demo (parent 80b4b43460, commit 2e2fe93458)
@@ -1,4 +1,5 @@
-""" Module for PINN """
+""" Module for GAROM """

 import torch
 try:
     from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
@@ -8,7 +9,7 @@ except ImportError:
     from torch.optim.lr_scheduler import ConstantLR
 from .solver import SolverInterface
 from ..utils import check_consistency
-from ..loss import LossInterface, LpLoss
+from ..loss import LossInterface, PowerLoss
 from torch.nn.modules.loss import _Loss


@@ -58,7 +59,7 @@ class GAROM(SolverInterface):
         extra features for each.
     :param torch.nn.Module loss: The loss function used as minimizer,
         default ``None``. If ``loss`` is ``None`` the defualt
-        ``LpLoss(p=1)`` is used, as in the original paper.
+        ``PowerLoss(p=1)`` is used, as in the original paper.
     :param torch.optim.Optimizer optimizer_generator: The neural
         network optimizer to use for the generator network
         , default is `torch.optim.Adam`.
@@ -102,7 +103,7 @@ class GAROM(SolverInterface):

         # set loss
         if loss is None:
-            loss = LpLoss(p=1)
+            loss = PowerLoss(p=1)

         # check consistency
         check_consistency(scheduler_generator, LRScheduler, subclass=True)
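For context, the sketch below illustrates the distinction the change hinges on, assuming (per the PINA documentation) that `LpLoss` aggregates the residual with a p-norm while `PowerLoss` averages the element-wise error raised to the power p. The helper names `lp_style_loss` and `power_style_loss` are illustrative stand-ins, not the library API.

import torch


def lp_style_loss(pred, target, p=1):
    # p-norm of the residual over the last dimension (LpLoss-like behaviour, assumed)
    return torch.linalg.vector_norm(pred - target, ord=p, dim=-1)


def power_style_loss(pred, target, p=1):
    # mean of the element-wise error raised to the power p (PowerLoss-like behaviour, assumed)
    return torch.mean(torch.abs(pred - target) ** p, dim=-1)


pred = torch.tensor([[0.5, 1.5, 2.0]])
target = torch.tensor([[1.0, 1.0, 1.0]])
print(lp_style_loss(pred, target))     # -> 2.0: sum of absolute errors for p = 1
print(power_style_loss(pred, target))  # -> 0.6667: mean of absolute errors for p = 1

Under that assumption, for p = 1 the two coincide up to a normalisation factor (sum versus mean of absolute errors), so the swap mainly changes the scale of the default objective.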