fix doc loss and codacy
This commit is contained in:
@@ -1,20 +1,41 @@
|
||||
"""Module for Loss Interface"""
|
||||
"""Module for the Scalar Weighting."""
|
||||
|
||||
from .weighting_interface import WeightingInterface
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
class _NoWeighting(WeightingInterface):
|
||||
"""
|
||||
Weighting scheme that does not apply any weighting to the losses.
|
||||
"""
|
||||
|
||||
def aggregate(self, losses):
|
||||
"""
|
||||
Aggregate the losses.
|
||||
|
||||
:param dict losses: The dictionary of losses.
|
||||
:return: The aggregated losses.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return sum(losses.values())
|
||||
|
||||
|
||||
class ScalarWeighting(WeightingInterface):
|
||||
"""
|
||||
TODO
|
||||
Weighting scheme that assigns a scalar weight to each loss term.
|
||||
"""
|
||||
|
||||
def __init__(self, weights):
|
||||
"""
|
||||
Initialization of the :class:`ScalarWeighting` class.
|
||||
|
||||
:param weights: The weights to be assigned to each loss term.
|
||||
If a single scalar value is provided, it is assigned to all loss
|
||||
terms. If a dictionary is provided, the keys are the conditions and
|
||||
the values are the weights. If a condition is not present in the
|
||||
dictionary, the default value is used.
|
||||
:type weights: float | int | dict
|
||||
"""
|
||||
super().__init__()
|
||||
check_consistency([weights], (float, dict, int))
|
||||
if isinstance(weights, (float, int)):
|
||||
@@ -28,8 +49,8 @@ class ScalarWeighting(WeightingInterface):
|
||||
"""
|
||||
Aggregate the losses.
|
||||
|
||||
:param dict(torch.Tensor) losses: The dictionary of losses.
|
||||
:return: The losses aggregation. It should be a scalar Tensor.
|
||||
:param dict losses: The dictionary of losses.
|
||||
:return: The aggregated losses.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return sum(
|
||||
|
||||
Reference in New Issue
Block a user