Update solvers (#434)

* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: giovanni <giovanni.canali98@yahoo.it>
Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
This commit is contained in:
Dario Coscia
2025-02-17 11:26:21 +01:00
committed by Nicola Demo
parent 780c4921eb
commit 9cae9a438f
50 changed files with 2848 additions and 4187 deletions

View File

@@ -1,11 +1,13 @@
__all__ = [
'LossInterface',
'LpLoss',
'PowerLoss',
'weightningInterface',
'LossInterface'
'WeightingInterface',
'ScalarWeighting'
]
from .loss_interface import LossInterface
from .power_loss import PowerLoss
from .lp_loss import LpLoss
from .weightning_interface import weightningInterface
from .weighting_interface import WeightingInterface
from .scalar_weighting import ScalarWeighting

View File

@@ -0,0 +1,36 @@
""" Module for Loss Interface """
from .weighting_interface import WeightingInterface
from ..utils import check_consistency
class _NoWeighting(WeightingInterface):
def aggregate(self, losses):
return sum(losses.values())
class ScalarWeighting(WeightingInterface):
"""
TODO
"""
def __init__(self, weights):
super().__init__()
check_consistency([weights], (float, dict, int))
if isinstance(weights, (float, int)):
self.default_value_weights = weights
self.weights = {}
else:
self.default_value_weights = 1
self.weights = weights
def aggregate(self, losses):
"""
Aggregate the losses.
:param dict(torch.Tensor) losses: The dictionary of losses.
:return: The losses aggregation. It should be a scalar Tensor.
:rtype: torch.Tensor
"""
return sum(
self.weights.get(condition, self.default_value_weights) * loss for
condition, loss in losses.items()
)

View File

@@ -1,35 +0,0 @@
""" Module for Loss Interface """
from .weightning_interface import weightningInterface
class WeightedAggregation(WeightningInterface):
"""
TODO
"""
def __init__(self, aggr='mean', weights=None):
self.aggr = aggr
self.weights = weights
def aggregate(self, losses):
"""
Aggregate the losses.
:param dict(torch.Tensor) input: The dictionary of losses.
:return: The losses aggregation. It should be a scalar Tensor.
:rtype: torch.Tensor
"""
if self.weights:
weighted_losses = {
condition: self.weights[condition] * losses[condition]
for condition in losses
}
else:
weighted_losses = losses
if self.aggr == 'mean':
return sum(weighted_losses.values()) / len(weighted_losses)
elif self.aggr == 'sum':
return sum(weighted_losses.values())
else:
raise ValueError(self.aggr + " is not valid for aggregation.")

View File

@@ -3,22 +3,21 @@
from abc import ABCMeta, abstractmethod
class weightningInterface(metaclass=ABCMeta):
class WeightingInterface(metaclass=ABCMeta):
"""
The ``weightingInterface`` class. TODO
"""
@abstractmethod
def __init__(self, *args, **kwargs):
pass
def __init__(self):
self.condition_names = None
@abstractmethod
def aggregate(self, losses):
"""
Aggregate the losses.
:param list(torch.Tensor) input: The list
:param dict(torch.Tensor) input: The dictionary of losses.
:return: The losses aggregation. It should be a scalar Tensor.
:rtype: torch.Tensor
"""
pass
pass