Dev Update (#582)
* Fix adaptive refinement (#571) --------- Co-authored-by: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> * Remove collector * Fixes * Fixes * rm unnecessary comment * fix advection (#581) * Fix tutorial .html link (#580) * fix problem data collection for v0.1 (#584) * Message Passing Module (#516) * add deep tensor network block * add interaction network block * add radial field network block * add schnet block * add equivariant network block * fix + tests + doc files * fix egnn + equivariance/invariance tests Co-authored-by: Dario Coscia <dariocos99@gmail.com> --------- Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it> * add type checker (#527) --------- Co-authored-by: Filippo Olivo <filippo@filippoolivo.com> Co-authored-by: Giovanni Canali <115086358+GiovanniCanali@users.noreply.github.com> Co-authored-by: giovanni <giovanni.canali98@yahoo.it> Co-authored-by: AleDinve <giuseppealessio.d@student.unisi.it>
This commit is contained in:
93
pina/type_checker.py
Normal file
93
pina/type_checker.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Module for enforcing type hints in Python functions."""
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
import logging
|
||||
|
||||
|
||||
def enforce_types(func):
|
||||
"""
|
||||
Function decorator to enforce type hints at runtime.
|
||||
|
||||
This decorator checks the types of the arguments and of the return value of
|
||||
the decorated function against the type hints specified in the function
|
||||
signature. If the types do not match, a TypeError is raised.
|
||||
Type checking is only performed when the logging level is set to `DEBUG`.
|
||||
|
||||
:param Callable func: The function to be decorated.
|
||||
:return: The decorated function with enforced type hints.
|
||||
:rtype: Callable
|
||||
|
||||
:Example:
|
||||
|
||||
>>> @enforce_types
|
||||
def dummy_function(a: int, b: float) -> float:
|
||||
... return a+b
|
||||
|
||||
# This always works.
|
||||
dummy_function(1, 2.0)
|
||||
|
||||
# This raises a TypeError for the second argument, if logging is set to
|
||||
# `DEBUG`.
|
||||
dummy_function(1, "Hello, world!")
|
||||
|
||||
|
||||
>>> @enforce_types
|
||||
def dummy_function2(a: int, right: bool) -> float:
|
||||
... if right:
|
||||
... return float(a)
|
||||
... else:
|
||||
... return "Hello, world!"
|
||||
|
||||
# This always works.
|
||||
dummy_function2(1, right=True)
|
||||
|
||||
# This raises a TypeError for the return value if logging is set to
|
||||
# `DEBUG`.
|
||||
dummy_function2(1, right=False)
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""
|
||||
Wrapper function to enforce type hints.
|
||||
|
||||
:param tuple args: Positional arguments passed to the function.
|
||||
:param dict kwargs: Keyword arguments passed to the function.
|
||||
:raises TypeError: If the argument or return type does not match the
|
||||
specified type hints.
|
||||
:return: The result of the decorated function.
|
||||
:rtype: Any
|
||||
"""
|
||||
level = logging.getLevelName(logging.getLogger().getEffectiveLevel())
|
||||
|
||||
# Enforce type hints only in debug mode
|
||||
if level != "DEBUG":
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Get the type hints for the function arguments
|
||||
hints = typing.get_type_hints(func)
|
||||
sig = inspect.signature(func)
|
||||
bound = sig.bind(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
|
||||
for arg_name, arg_value in bound.arguments.items():
|
||||
expected_type = hints.get(arg_name)
|
||||
if expected_type and not isinstance(arg_value, expected_type):
|
||||
raise TypeError(
|
||||
f"Argument '{arg_name}' must be {expected_type.__name__}, "
|
||||
f"but got {type(arg_value).__name__}!"
|
||||
)
|
||||
|
||||
# Get the type hints for the return values
|
||||
return_type = hints.get("return")
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
if return_type and not isinstance(result, return_type):
|
||||
raise TypeError(
|
||||
f"Return value must be {return_type.__name__}, "
|
||||
f"but got {type(result).__name__}!"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user