fix doc loss and codacy

giovanni
2025-03-12 18:05:42 +01:00
committed by Nicola Demo
parent 31de079daa
commit cbf886e53e
15 changed files with 114 additions and 108 deletions

@@ -1,20 +1,41 @@
-"""Module for Loss Interface"""
+"""Module for the Scalar Weighting."""
+
 from .weighting_interface import WeightingInterface
 from ..utils import check_consistency


 class _NoWeighting(WeightingInterface):
+    """
+    Weighting scheme that does not apply any weighting to the losses.
+    """

     def aggregate(self, losses):
+        """
+        Aggregate the losses.
+
+        :param dict losses: The dictionary of losses.
+        :return: The aggregated losses.
+        :rtype: torch.Tensor
+        """
         return sum(losses.values())


 class ScalarWeighting(WeightingInterface):
     """
-    TODO
+    Weighting scheme that assigns a scalar weight to each loss term.
     """

     def __init__(self, weights):
+        """
+        Initialization of the :class:`ScalarWeighting` class.
+
+        :param weights: The weights to be assigned to each loss term.
+            If a single scalar value is provided, it is assigned to all loss
+            terms. If a dictionary is provided, the keys are the conditions
+            and the values are the weights. If a condition is not present in
+            the dictionary, the default value is used.
+        :type weights: float | int | dict
+        """
         super().__init__()
         check_consistency([weights], (float, dict, int))
         if isinstance(weights, (float, int)):
@@ -28,8 +49,8 @@ class ScalarWeighting(WeightingInterface):
         """
         Aggregate the losses.

-        :param dict(torch.Tensor) losses: The dictionary of losses.
-        :return: The losses aggregation. It should be a scalar Tensor.
+        :param dict losses: The dictionary of losses.
+        :return: The aggregated losses.
         :rtype: torch.Tensor
         """
         return sum(
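
For context, a minimal usage sketch of the two weighting behaviours this diff documents. This is an illustration, not part of the commit: the pina.loss import path, the condition names, and the tensor values are assumptions; only the ScalarWeighting(weights) constructor, the aggregate(losses) method, and the fall-back-to-default behaviour are taken from the docstrings above.

import torch

from pina.loss import ScalarWeighting  # import path assumed

# Hypothetical per-condition losses, as a dict of scalar tensors.
losses = {
    "physics": torch.tensor(0.8),
    "boundary": torch.tensor(0.2),
}

# A single scalar applies the same weight to every loss term.
uniform = ScalarWeighting(weights=0.5)
print(uniform.aggregate(losses))  # 0.5 * (0.8 + 0.2) = tensor(0.5000)

# A dict assigns per-condition weights; conditions missing from the dict
# fall back to the default value (per the new docstring).
per_condition = ScalarWeighting(weights={"physics": 10.0})
print(per_condition.aggregate(losses))  # 10.0 * 0.8 + default * 0.2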