Files
PINA/pina/loss/weighting_interface.py
Dario Coscia 42ab1a666b Formatting
* Adding black as dev dependency
* Formatting pina code
* Formatting tests
2025-03-19 17:46:36 +01:00

24 lines
529 B
Python

"""Module for Loss Interface"""
from abc import ABCMeta, abstractmethod
class WeightingInterface(metaclass=ABCMeta):
    """
    Abstract base class for loss weighting schemes.

    A weighting scheme combines a dictionary of individual losses into a
    single value through :meth:`aggregate`. Subclasses must implement
    :meth:`aggregate`; because it is declared abstract, this class cannot be
    instantiated directly.
    """

    def __init__(self):
        """
        Initialize the weighting with no condition names set.
        """
        # Left as ``None`` here; presumably assigned later by the owning
        # object (e.g. a solver) before aggregation.
        # NOTE(review): confirm against callers.
        self.condition_names = None

    @abstractmethod
    def aggregate(self, losses):
        """
        Aggregate the losses into a single value.

        :param dict losses: The dictionary of losses. Values are expected to
            be :class:`torch.Tensor` objects (keys presumably identify the
            individual conditions -- to be confirmed against callers).
        :return: The aggregation of the losses. It should be a scalar tensor.
        :rtype: torch.Tensor
        """