Formatting
* Adding black as dev dependency * Formatting pina code * Formatting tests
This commit is contained in:
committed by
Nicola Demo
parent
4c4482b155
commit
42ab1a666b
@@ -1,4 +1,4 @@
|
||||
""" Module for Loss Interface """
|
||||
"""Module for Loss Interface"""
|
||||
|
||||
from .weighting_interface import WeightingInterface
|
||||
from ..utils import check_consistency
|
||||
@@ -8,10 +8,12 @@ class _NoWeighting(WeightingInterface):
|
||||
def aggregate(self, losses):
|
||||
return sum(losses.values())
|
||||
|
||||
|
||||
class ScalarWeighting(WeightingInterface):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
check_consistency([weights], (float, dict, int))
|
||||
@@ -31,6 +33,6 @@ class ScalarWeighting(WeightingInterface):
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return sum(
|
||||
self.weights.get(condition, self.default_value_weights) * loss for
|
||||
condition, loss in losses.items()
|
||||
self.weights.get(condition, self.default_value_weights) * loss
|
||||
for condition, loss in losses.items()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user