* fix doc loss + adding PowerLoss
* adding loss tests folder

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
committed by Nicola Demo
parent 4850b0045d
commit cc3332b519
@@ -40,6 +40,15 @@ Layers
 
     ContinuousConv <convolution.rst>
 
+Loss
+------
+
+.. toctree::
+    :maxdepth: 3
+
+    LpLoss <lploss.rst>
+    PowerLoss <powerloss.rst>
+
 Problem
 -------
docs/source/_rst/lploss.rst (new file, 10 lines)
@@ -0,0 +1,10 @@
+LpLoss
+====
+.. currentmodule:: pina.loss
+
+.. automodule:: pina.loss
+
+.. autoclass:: LpLoss
+    :members:
+    :private-members:
+    :show-inheritance:
@@ -1,8 +1,8 @@
 PINN
 ====
-.. currentmodule:: pina.pinn
+.. currentmodule:: pina.solvers.pinn
 
-.. automodule:: pina.pinn
+.. automodule:: pina.solvers.pinn
 
 .. autoclass:: PINN
     :members:
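The two directive fixes above track a module move: the PINN solver now lives in the ``pina.solvers`` subpackage. Downstream imports should follow suit; a one-line sketch, assuming the class is importable from the new path exactly as the docs now reference it:

    # before this fix the docs pointed at pina.pinn; the class now lives here
    from pina.solvers.pinn import PINN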
docs/source/_rst/powerloss.rst (new file, 10 lines)
@@ -0,0 +1,10 @@
+PowerLoss
+=========
+.. currentmodule:: pina.loss
+
+.. automodule:: pina.loss
+
+.. autoclass:: PowerLoss
+    :members:
+    :private-members:
+    :show-inheritance:
pina/loss.py (90 lines changed)
@@ -1,4 +1,6 @@
-""" Module for EquationInterface class """
+""" Module for Loss class """
+
+
 from abc import ABCMeta, abstractmethod
 from torch.nn.modules.loss import _Loss
 import torch
@@ -55,25 +57,25 @@ class LossInterface(_Loss, metaclass=ABCMeta):
         return ret
 
 
 class LpLoss(LossInterface):
-    """
+    r"""
     The Lp loss implementation class. Creates a criterion that measures
     the Lp error between each element in the input :math:`x` and
     target :math:`y`.
 
-    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can
+    The unreduced (i.e. with :attr:`reduction` set to ``none``) loss can
     be described as:
 
     .. math::
         \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
-        l_n = \left| x_n - y_n \right|^p,
+        l_n = \left[\sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p \right],
 
     If ``'relative'`` is set to true:
 
     .. math::
         \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
-        l_n = \left[\frac{\left| x_n - y_n \right|^p}{\left| y_n \right|^p}\right]^{1/p},
+        l_n = \frac{\sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p}{\sum_{i=1}^{D} \left| y_n^i \right|^p},
 
-    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
+    where :math:`N` is the batch size. If :attr:`reduction` is not ``none``
     (default ``'mean'``), then:
 
     .. math::
@@ -88,7 +90,7 @@ class LpLoss(LossInterface):
 
     The sum operation still operates over all the elements, and divides by :math:`n`.
 
-    The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
+    The division by :math:`n` can be avoided if one sets :attr:`reduction` to ``sum``.
     """
 
     def __init__(self, p=2, reduction = 'mean', relative = False):
@@ -125,3 +127,77 @@ class LpLoss(LossInterface):
         if self.relative:
             loss = loss / torch.linalg.norm(input, ord=self.p, dim=-1)
         return self._reduction(loss)
+
+
+class PowerLoss(LossInterface):
+    r"""
+    The PowerLoss implementation class. Creates a criterion that measures
+    the error between each element in the input :math:`x` and
+    target :math:`y`, raised to a given power.
+
+    The unreduced (i.e. with :attr:`reduction` set to ``none``) loss can
+    be described as:
+
+    .. math::
+        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+        l_n = \frac{1}{D}\left[\sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p \right],
+
+    If ``'relative'`` is set to true:
+
+    .. math::
+        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+        l_n = \frac{\sum_{i=1}^{D} \left| x_n^i - y_n^i \right|^p}{\sum_{i=1}^{D} \left| y_n^i \right|^p},
+
+    where :math:`N` is the batch size. If :attr:`reduction` is not ``none``
+    (default ``'mean'``), then:
+
+    .. math::
+        \ell(x, y) =
+        \begin{cases}
+            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
+    of :math:`n` elements each.
+
+    The sum operation still operates over all the elements, and divides by :math:`n`.
+
+    The division by :math:`n` can be avoided if one sets :attr:`reduction` to ``sum``.
+    """
+
+    def __init__(self, p=2, reduction = 'mean', relative = False):
+        """
+        :param int p: Degree of the Lp norm. It specifies the type of norm
+            to be calculated. See the ``ord`` argument of
+            :meth:`torch.linalg.norm` for the possible degrees.
+            Default 2 (euclidean norm).
+        :param str reduction: Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+            will be applied, ``'mean'``: the sum of the output will be divided
+            by the number of elements in the output, ``'sum'``: the output will
+            be summed. Note: :attr:`size_average` and :attr:`reduce` are in
+            the process of being deprecated, and in the meantime, specifying
+            either of those two args will override :attr:`reduction`.
+            Default: ``'mean'``.
+        :param bool relative: Specifies if the relative error should be computed.
+        """
+        super().__init__(reduction=reduction)
+
+        # check consistency
+        check_consistency(p, (str, int, float))
+        self.p = p
+        check_consistency(relative, bool)
+        self.relative = relative
+
+    def forward(self, input, target):
+        """Forward method for the loss function.
+
+        :param torch.Tensor input: Input tensor from real data.
+        :param torch.Tensor target: Model tensor output.
+        :return: Loss evaluation.
+        :rtype: torch.Tensor
+        """
+        loss = torch.abs((input-target)).pow(self.p).mean(-1)
+        if self.relative:
+            loss = loss / torch.abs(input).pow(self.p).mean(-1)
+        return self._reduction(loss)
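As a reading aid for the new class, here is a minimal sketch of what ``PowerLoss`` computes, following the ``forward`` shown above. The tensors are made up, and the assumption that ``reduction='none'`` returns the per-sample vector comes from the docstring, not from code visible in this diff:

    import torch

    from pina.loss import PowerLoss

    # made-up batch: N = 3 samples, D = 2 components each
    x = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    y = torch.tensor([[1., 0.], [0., 4.], [5., 5.]])

    # unreduced loss: per-sample mean over D of |x - y|^p
    per_sample = torch.abs(x - y).pow(2).mean(-1)
    assert torch.allclose(PowerLoss(p=2, reduction='none')(x, y), per_sample)

    # the 'mean' reduction then averages the N per-sample values
    assert torch.isclose(PowerLoss(p=2, reduction='mean')(x, y), per_sample.mean())

    # relative=True divides, per sample, by the mean over D of |x|^p
    rel = per_sample / torch.abs(x).pow(2).mean(-1)
    assert torch.allclose(PowerLoss(p=2, reduction='none', relative=True)(x, y), rel)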
tests/test_loss/test_powerloss.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+import torch
+import pytest
+
+from pina.loss import PowerLoss
+
+input = torch.tensor([[3.], [1.], [-8.]])
+target = torch.tensor([[6.], [4.], [2.]])
+available_reductions = ['str', 'mean', 'none']
+
+
+def test_PowerLoss_constructor():
+    # test reduction
+    for reduction in available_reductions:
+        PowerLoss(reduction=reduction)
+    # test p
+    for p in [float('inf'), -float('inf'), 1, 10, -8]:
+        PowerLoss(p=p)
+
+
+def test_PowerLoss_forward():
+    # l2 loss
+    loss = PowerLoss(p=2, reduction='mean')
+    l2_loss = torch.mean((input-target).pow(2))
+    assert loss(input, target) == l2_loss
+    # l1 loss
+    loss = PowerLoss(p=1, reduction='sum')
+    l1_loss = torch.sum(torch.abs(input-target))
+    assert loss(input, target) == l1_loss
+
+
+def test_LpRelativeLoss_constructor():
+    # test reduction
+    for reduction in available_reductions:
+        PowerLoss(reduction=reduction, relative=True)
+    # test p
+    for p in [float('inf'), -float('inf'), 1, 10, -8]:
+        PowerLoss(p=p, relative=True)
+
+
+def test_LpRelativeLoss_forward():
+    # l2 relative loss
+    loss = PowerLoss(p=2, reduction='mean', relative=True)
+    l2_loss = (input-target).pow(2)/input.pow(2)
+    assert loss(input, target) == torch.mean(l2_loss)
+    # l1 relative loss
+    loss = PowerLoss(p=1, reduction='sum', relative=True)
+    l1_loss = torch.abs(input-target)/torch.abs(input)
+    assert loss(input, target) == torch.sum(l1_loss)
+
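Two quick notes on the new suite: ``'str'`` in ``available_reductions`` looks like a typo for ``'sum'``, which would leave the ``'sum'`` reduction untested in the constructor checks; and the folder can be run in isolation. A sketch of the latter, equivalent to ``pytest tests/test_loss`` from the shell:

    # run only the new loss tests quietly, propagating pytest's exit code
    import pytest

    raise SystemExit(pytest.main(['tests/test_loss', '-q']))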