""" Module for Loss Interface """
|
|
|
|
from .weighting_interface import WeightingInterface
|
|
from ..utils import check_consistency
|
|
|
|
|
|
class _NoWeighting(WeightingInterface):
|
|
def aggregate(self, losses):
|
|
return sum(losses.values())
|
|
|
|
class ScalarWeighting(WeightingInterface):
|
|
"""
|
|
TODO
|
|
"""
|
|
def __init__(self, weights):
|
|
super().__init__()
|
|
check_consistency([weights], (float, dict, int))
|
|
if isinstance(weights, (float, int)):
|
|
self.default_value_weights = weights
|
|
self.weights = {}
|
|
else:
|
|
self.default_value_weights = 1
|
|
self.weights = weights
|
|
|
|
def aggregate(self, losses):
|
|
"""
|
|
Aggregate the losses.
|
|
|
|
:param dict(torch.Tensor) losses: The dictionary of losses.
|
|
:return: The losses aggregation. It should be a scalar Tensor.
|
|
:rtype: torch.Tensor
|
|
"""
|
|
return sum(
|
|
self.weights.get(condition, self.default_value_weights) * loss for
|
|
condition, loss in losses.items()
|
|
)
|
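A minimal usage sketch of ScalarWeighting.aggregate, assuming the class above is importable (the import path below is an assumption, adjust it to your package layout), that torch is installed, and with made-up condition names and loss values for illustration.

import torch

# Assumed import path for the class defined above; adjust to your installation.
from mypackage.loss.scalar_weighting import ScalarWeighting

# Per-condition losses as a solver might produce them (names are illustrative).
losses = {
    "physics": torch.tensor(0.50),
    "data": torch.tensor(0.10),
    "boundary": torch.tensor(0.25),
}

# Weight only some conditions; unlisted ones fall back to the default weight,
# which is 1 because a dict was passed.
weighting = ScalarWeighting(weights={"physics": 10.0, "data": 2.0})
total = weighting.aggregate(losses)
# total is 10.0 * 0.50 + 2.0 * 0.10 + 1 * 0.25, i.e. roughly tensor(5.45).

# Passing a plain number instead, e.g. ScalarWeighting(2.5), weights every
# condition by that same scalar.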