62 lines
2.4 KiB
Python
62 lines
2.4 KiB
Python
"""Module for Loss Interface"""
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
from torch.nn.modules.loss import _Loss
|
|
import torch
|
|
|
|
|
|
class LossInterface(_Loss, metaclass=ABCMeta):
    """
    The abstract ``LossInterface`` class. All classes defining a PINA loss
    should inherit from this class.

    Subclasses must implement :meth:`forward`; the :meth:`_reduction` helper
    is available to apply the configured reduction to an element-wise loss.
    """

    def __init__(self, reduction="mean"):
        """
        :param str reduction: Specifies the reduction that
            :meth:`_reduction` applies to the loss tensor:
            ``none`` | ``mean`` | ``sum``. When ``none``: no reduction is
            applied, ``mean``: the loss is averaged along its last
            dimension, ``sum``: the loss is summed along its last
            dimension. The value is not validated here; an unsupported
            value is only reported when :meth:`_reduction` is called.
            Default: ``mean``.

        The deprecated PyTorch arguments ``size_average`` and ``reduce`` are
        always forwarded as ``None``.
        """
        super().__init__(reduction=reduction, size_average=None, reduce=None)

    @abstractmethod
    def forward(self, input, target):
        """Forward method for loss function.

        NOTE(review): PyTorch losses conventionally take the prediction as
        ``input`` and the reference as ``target``; the roles documented
        below are kept from the original text and should be confirmed.

        :param torch.Tensor input: Input tensor from real data.
        :param torch.Tensor target: Model tensor output.
        :return: Loss evaluation.
        :rtype: torch.Tensor
        """

    def _reduction(self, loss):
        """Apply the reduction selected by ``self.reduction`` to ``loss``.

        ``none`` returns ``loss`` unchanged. ``mean`` and ``sum`` reduce
        along the last dimension only and keep that dimension with size 1.

        :param torch.Tensor loss: Loss tensor for each element.
        :return: Reduced loss.
        :rtype: torch.Tensor
        :raises ValueError: If ``self.reduction`` is not one of ``none``,
            ``mean`` or ``sum``.
        """
        if self.reduction == "none":
            return loss
        if self.reduction == "mean":
            return torch.mean(loss, keepdim=True, dim=-1)
        if self.reduction == "sum":
            return torch.sum(loss, keepdim=True, dim=-1)
        # Format instead of concatenating: a non-string value such as ``None``
        # must still produce the ValueError, not a TypeError from ``+``.
        raise ValueError(f"{self.reduction} is not valid")
|