diff --git a/pina/utils.py b/pina/utils.py index 569ba63..b72d500 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -47,15 +47,43 @@ def check_consistency(object_, object_instance, subclass=False): object_ = [object_] for obj in object_: - try: - if not subclass: - assert isinstance(obj, object_instance) - else: - assert issubclass(obj, object_instance) - except AssertionError as e: - raise ValueError( - f"{type(obj).__name__} must be {object_instance}." - ) from e + is_class = isinstance(obj, type) + expected_type_name = ( + object_instance.__name__ + if isinstance(object_instance, type) + else str(object_instance) + ) + + if subclass: + if not is_class: + raise ValueError( + f"You passed {repr(obj)} " + f"(an instance of {type(obj).__name__}), " + f"but a {expected_type_name} class was expected. " + f"Please pass a {expected_type_name} class or a " + "derived one." + ) + elif not issubclass(obj, object_instance): + raise ValueError( + f"You passed {obj.__name__} class, but a " + f"{expected_type_name} class was expected. " + f"Please pass a {expected_type_name} class or a " + "derived one." + ) + else: + if is_class: + raise ValueError( + f"You passed {obj.__name__} class, but a " + f"{expected_type_name} instance was expected. " + f"Please pass a {expected_type_name} instance." + ) + elif not isinstance(obj, object_instance): + raise ValueError( + f"You passed {repr(obj)} " + f"(an instance of {type(obj).__name__}), " + f"but a {expected_type_name} instance was expected. " + f"Please pass a {expected_type_name} instance." + ) def labelize_forward(forward, input_variables, output_variables):