Update solvers (#434)
* Enable DDP training with batch_size=None and add validity check for split sizes * Refactoring SolverInterfaces (#435) * Solver update + weighting * Updating PINN for 0.2 * Modify GAROM + tests * Adding more versatile loggers * Disable compilation when running on Windows * Fix tests --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: FilippoOlivo <filippo@filippoolivo.com>
This commit is contained in:
committed by
Nicola Demo
parent
780c4921eb
commit
9cae9a438f
36
pina/loss/scalar_weighting.py
Normal file
36
pina/loss/scalar_weighting.py
Normal file
@@ -0,0 +1,36 @@
|
||||
""" Module for Loss Interface """
|
||||
|
||||
from .weighting_interface import WeightingInterface
|
||||
from ..utils import check_consistency
|
||||
|
||||
|
||||
class _NoWeighting(WeightingInterface):
|
||||
def aggregate(self, losses):
|
||||
return sum(losses.values())
|
||||
|
||||
class ScalarWeighting(WeightingInterface):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
check_consistency([weights], (float, dict, int))
|
||||
if isinstance(weights, (float, int)):
|
||||
self.default_value_weights = weights
|
||||
self.weights = {}
|
||||
else:
|
||||
self.default_value_weights = 1
|
||||
self.weights = weights
|
||||
|
||||
def aggregate(self, losses):
|
||||
"""
|
||||
Aggregate the losses.
|
||||
|
||||
:param dict(torch.Tensor) losses: The dictionary of losses.
|
||||
:return: The losses aggregation. It should be a scalar Tensor.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return sum(
|
||||
self.weights.get(condition, self.default_value_weights) * loss for
|
||||
condition, loss in losses.items()
|
||||
)
|
||||
Reference in New Issue
Block a user