Formatting

* Adding black as dev dependency
* Formatting pina code
* Formatting tests
This commit is contained in:
Dario Coscia
2025-02-24 11:26:49 +01:00
committed by Nicola Demo
parent 4c4482b155
commit 42ab1a666b
77 changed files with 1170 additions and 924 deletions

View File

@@ -1,4 +1,4 @@
""" Module for Loss Interface """
"""Module for Loss Interface"""
from .weighting_interface import WeightingInterface
from ..utils import check_consistency
@@ -8,10 +8,12 @@ class _NoWeighting(WeightingInterface):
def aggregate(self, losses):
return sum(losses.values())
class ScalarWeighting(WeightingInterface):
"""
TODO
"""
def __init__(self, weights):
super().__init__()
check_consistency([weights], (float, dict, int))
@@ -31,6 +33,6 @@ class ScalarWeighting(WeightingInterface):
:rtype: torch.Tensor
"""
return sum(
self.weights.get(condition, self.default_value_weights) * loss for
condition, loss in losses.items()
self.weights.get(condition, self.default_value_weights) * loss
for condition, loss in losses.items()
)