Update solvers (#434)

* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
This commit is contained in:
Dario Coscia
2025-02-17 11:26:21 +01:00
committed by Nicola Demo
parent 780c4921eb
commit 9cae9a438f
50 changed files with 2848 additions and 4187 deletions

View File

@@ -0,0 +1,36 @@
""" Module for Loss Interface """
from .weighting_interface import WeightingInterface
from ..utils import check_consistency
class _NoWeighting(WeightingInterface):
def aggregate(self, losses):
return sum(losses.values())
class ScalarWeighting(WeightingInterface):
"""
TODO
"""
def __init__(self, weights):
super().__init__()
check_consistency([weights], (float, dict, int))
if isinstance(weights, (float, int)):
self.default_value_weights = weights
self.weights = {}
else:
self.default_value_weights = 1
self.weights = weights
def aggregate(self, losses):
"""
Aggregate the losses.
:param dict(torch.Tensor) losses: The dictionary of losses.
:return: The losses aggregation. It should be a scalar Tensor.
:rtype: torch.Tensor
"""
return sum(
self.weights.get(condition, self.default_value_weights) * loss for
condition, loss in losses.items()
)