Update solvers (#434)
* Enable DDP training with batch_size=None and add validity check for split sizes * Refactoring SolverInterfaces (#435) * Solver update + weighting * Updating PINN for 0.2 * Modify GAROM + tests * Adding more versatile loggers * Disable compilation when running on Windows * Fix tests --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
This commit is contained in:
committed by
Nicola Demo
parent
780c4921eb
commit
9cae9a438f
@@ -1,11 +1,13 @@
|
||||
__all__ = [
|
||||
'LossInterface',
|
||||
'LpLoss',
|
||||
'PowerLoss',
|
||||
'weightningInterface',
|
||||
'LossInterface'
|
||||
'WeightingInterface',
|
||||
'ScalarWeighting'
|
||||
]
|
||||
|
||||
from .loss_interface import LossInterface
|
||||
from .power_loss import PowerLoss
|
||||
from .lp_loss import LpLoss
|
||||
from .weightning_interface import weightningInterface
|
||||
from .weighting_interface import WeightingInterface
|
||||
from .scalar_weighting import ScalarWeighting
|
||||
|
||||
36
pina/loss/scalar_weighting.py
Normal file
36
pina/loss/scalar_weighting.py
Normal file
@@ -0,0 +1,36 @@
|
||||
""" Module for Loss Interface """
|
||||
|
||||
from .weighting_interface import WeightingInterface
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
class _NoWeighting(WeightingInterface):
    """
    Default weighting scheme: every condition loss contributes with unit
    weight, i.e. the aggregation is a plain sum.
    """

    def aggregate(self, losses):
        """
        Aggregate the losses by unweighted summation.

        :param dict(torch.Tensor) losses: The dictionary of losses.
        :return: The losses aggregation. It should be a scalar Tensor.
        :rtype: torch.Tensor
        """
        return sum(losses.values())
|
||||
|
||||
class ScalarWeighting(WeightingInterface):
    """
    Weighting scheme that scales each condition loss by a fixed scalar.

    If ``weights`` is a single number, that value is applied to every
    condition. If it is a dictionary, each listed condition gets its own
    scalar weight, and any condition missing from the dictionary falls
    back to a default weight of ``1``.
    """

    def __init__(self, weights):
        """
        :param weights: A single scalar applied to all conditions, or a
            dictionary mapping condition names to scalar weights.
        :type weights: float | int | dict
        """
        super().__init__()
        check_consistency([weights], (float, dict, int))
        if isinstance(weights, (float, int)):
            # One scalar for every condition; no per-condition overrides.
            self.default_value_weights = weights
            self.weights = {}
        else:
            # Per-condition weights; unlisted conditions default to 1.
            self.default_value_weights = 1
            self.weights = weights

    def aggregate(self, losses):
        """
        Aggregate the losses.

        :param dict(torch.Tensor) losses: The dictionary of losses.
        :return: The losses aggregation. It should be a scalar Tensor.
        :rtype: torch.Tensor
        """
        return sum(
            self.weights.get(condition, self.default_value_weights) * loss
            for condition, loss in losses.items()
        )
|
||||
@@ -1,35 +0,0 @@
|
||||
""" Module for Loss Interface """
|
||||
|
||||
from .weightning_interface import weightningInterface
|
||||
|
||||
|
||||
class WeightedAggregation(weightningInterface):
    # BUG FIX: the base class was spelled ``WeightningInterface`` while the
    # module imports ``weightningInterface`` — that raised NameError as soon
    # as this class statement executed. Use the imported name.
    """
    Aggregates a dictionary of condition losses, optionally scaling each
    loss by a per-condition weight, then reducing by mean or sum.
    """

    def __init__(self, aggr='mean', weights=None):
        """
        :param str aggr: Reduction to apply, either ``'mean'`` or ``'sum'``.
        :param dict weights: Optional mapping from condition name to scalar
            weight. When given, every condition present in the losses must
            have an entry (missing keys raise ``KeyError``); ``None``
            disables weighting.
        """
        self.aggr = aggr
        self.weights = weights

    def aggregate(self, losses):
        """
        Aggregate the losses.

        :param dict(torch.Tensor) losses: The dictionary of losses.
        :return: The losses aggregation. It should be a scalar Tensor.
        :rtype: torch.Tensor
        :raises ValueError: If ``aggr`` is neither ``'mean'`` nor ``'sum'``.
        """
        if self.weights:
            weighted_losses = {
                condition: self.weights[condition] * losses[condition]
                for condition in losses
            }
        else:
            weighted_losses = losses

        if self.aggr == 'mean':
            return sum(weighted_losses.values()) / len(weighted_losses)
        elif self.aggr == 'sum':
            return sum(weighted_losses.values())
        else:
            raise ValueError(self.aggr + " is not valid for aggregation.")
|
||||
@@ -3,22 +3,21 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class WeightingInterface(metaclass=ABCMeta):
    """
    Abstract base class for loss-weighting schemes. Subclasses must
    implement :meth:`aggregate` to reduce a dictionary of condition
    losses to a single scalar.
    """

    def __init__(self):
        # Names of the weighted conditions; initialized to None here.
        # NOTE(review): presumably populated externally by the solver —
        # confirm against the calling code.
        self.condition_names = None

    @abstractmethod
    def aggregate(self, losses):
        """
        Aggregate the losses.

        :param dict(torch.Tensor) losses: The dictionary of losses.
        :return: The losses aggregation. It should be a scalar Tensor.
        :rtype: torch.Tensor
        """
        pass
|
||||
Reference in New Issue
Block a user