Files
PINA/pina/loss/scalar_weighting.py
Dario Coscia 9cae9a438f Update solvers (#434)
* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
2025-03-19 17:46:35 +01:00

37 lines
1.0 KiB
Python

""" Module for Loss Interface """
from .weighting_interface import WeightingInterface
from ..utils import check_consistency
class _NoWeighting(WeightingInterface):
def aggregate(self, losses):
return sum(losses.values())
class ScalarWeighting(WeightingInterface):
"""
TODO
"""
def __init__(self, weights):
super().__init__()
check_consistency([weights], (float, dict, int))
if isinstance(weights, (float, int)):
self.default_value_weights = weights
self.weights = {}
else:
self.default_value_weights = 1
self.weights = weights
def aggregate(self, losses):
"""
Aggregate the losses.
:param dict(torch.Tensor) losses: The dictionary of losses.
:return: The losses aggregation. It should be a scalar Tensor.
:rtype: torch.Tensor
"""
return sum(
self.weights.get(condition, self.default_value_weights) * loss for
condition, loss in losses.items()
)