* Fix adaptive refinement (#571) --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> * Remove collector * Fixes * Fixes * rm unnecessary comment * fix advection (#581) * Fix tutorial .html link (#580) * fix problem data collection for v0.1 (#584) * Message Passing Module (#516) * add deep tensor network block * add interaction network block * add radial field network block * add schnet block * add equivariant network block * fix + tests + doc files * fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia <dariocos99@gmail.com> --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it> * add type checker (#527) --------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
94 lines
2.9 KiB
Python
94 lines
2.9 KiB
Python
"""Module for enforcing type hints in Python functions."""
|
|
|
|
import inspect
|
|
import typing
|
|
import logging
|
|
|
|
|
|
def enforce_types(func):
|
|
"""
|
|
Function decorator to enforce type hints at runtime.
|
|
|
|
This decorator checks the types of the arguments and of the return value of
|
|
the decorated function against the type hints specified in the function
|
|
signature. If the types do not match, a TypeError is raised.
|
|
Type checking is only performed when the logging level is set to `DEBUG`.
|
|
|
|
:param Callable func: The function to be decorated.
|
|
:return: The decorated function with enforced type hints.
|
|
:rtype: Callable
|
|
|
|
:Example:
|
|
|
|
>>> @enforce_types
|
|
def dummy_function(a: int, b: float) -> float:
|
|
... return a+b
|
|
|
|
# This always works.
|
|
dummy_function(1, 2.0)
|
|
|
|
# This raises a TypeError for the second argument, if logging is set to
|
|
# `DEBUG`.
|
|
dummy_function(1, "Hello, world!")
|
|
|
|
|
|
>>> @enforce_types
|
|
def dummy_function2(a: int, right: bool) -> float:
|
|
... if right:
|
|
... return float(a)
|
|
... else:
|
|
... return "Hello, world!"
|
|
|
|
# This always works.
|
|
dummy_function2(1, right=True)
|
|
|
|
# This raises a TypeError for the return value if logging is set to
|
|
# `DEBUG`.
|
|
dummy_function2(1, right=False)
|
|
"""
|
|
|
|
def wrapper(*args, **kwargs):
|
|
"""
|
|
Wrapper function to enforce type hints.
|
|
|
|
:param tuple args: Positional arguments passed to the function.
|
|
:param dict kwargs: Keyword arguments passed to the function.
|
|
:raises TypeError: If the argument or return type does not match the
|
|
specified type hints.
|
|
:return: The result of the decorated function.
|
|
:rtype: Any
|
|
"""
|
|
level = logging.getLevelName(logging.getLogger().getEffectiveLevel())
|
|
|
|
# Enforce type hints only in debug mode
|
|
if level != "DEBUG":
|
|
return func(*args, **kwargs)
|
|
|
|
# Get the type hints for the function arguments
|
|
hints = typing.get_type_hints(func)
|
|
sig = inspect.signature(func)
|
|
bound = sig.bind(*args, **kwargs)
|
|
bound.apply_defaults()
|
|
|
|
for arg_name, arg_value in bound.arguments.items():
|
|
expected_type = hints.get(arg_name)
|
|
if expected_type and not isinstance(arg_value, expected_type):
|
|
raise TypeError(
|
|
f"Argument '{arg_name}' must be {expected_type.__name__}, "
|
|
f"but got {type(arg_value).__name__}!"
|
|
)
|
|
|
|
# Get the type hints for the return values
|
|
return_type = hints.get("return")
|
|
result = func(*args, **kwargs)
|
|
|
|
if return_type and not isinstance(result, return_type):
|
|
raise TypeError(
|
|
f"Return value must be {return_type.__name__}, "
|
|
f"but got {type(result).__name__}!"
|
|
)
|
|
|
|
return result
|
|
|
|
return wrapper
|