type_enforced.enforcer

  1from types import (
  2    FunctionType,
  3    MethodType,
  4    GeneratorType,
  5    BuiltinFunctionType,
  6    BuiltinMethodType,
  7    UnionType,
  8)
  9from typing import Type, Union, Sized, Literal, Callable, get_type_hints, Any
 10from functools import update_wrapper, wraps
 11from type_enforced.utils import (
 12    Partial,
 13    GenericConstraint,
 14    DeepMerge,
 15    iterable_types,
 16)
 17import sys, traceback
 18from pathlib import Path
 19
 20
 21class FunctionMethodEnforcer:
 22    def __init__(self, __fn__, __strict__=False, __clean_traceback__=True):
 23        """
 24        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.
 25
 26        Requires:
 27
 28            - `__fn__`:
 29                - What: The function to enforce
 30                - Type: function | method | class
 31
 32        Optional:
 33
 34            - `__strict__`:
 35                - What: A boolean to enable or disable exceptions. If True, exceptions will be raised
 36                    when type checking fails. If False, exceptions will not be raised but instead a warning
 37                    will be printed to the console.
 38                - Type: bool
 39                - Default: False
 40            - `__clean_traceback__`:
 41                - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
 42                - Type: bool
 43                - Default: True
 44        """
 45        update_wrapper(self, __fn__)
 46        self.__fn__ = __fn__
 47        self.__strict__ = __strict__
 48        self.__clean_traceback__ = __clean_traceback__
 49        self.__outer_self__ = None
 50        # Validate that the passed function or method is a method or function
 51        self.__check_method_function__()
 52        # Get input defaults for the function or method
 53        self.__get_defaults__()
 54
 55    def __get_defaults__(self):
 56        """
 57        Get the default values of the passed function or method and store them in `self.__fn_defaults__`.
 58        """
 59        self.__fn_defaults__ = {}
 60        if self.__fn__.__defaults__ is not None:
 61            # Get the names of all provided default values for args
 62            default_varnames = list(self.__fn__.__code__.co_varnames)[
 63                : self.__fn__.__code__.co_argcount
 64            ][-len(self.__fn__.__defaults__) :]
 65            # Update the output dictionary with the default values
 66            self.__fn_defaults__.update(
 67                dict(zip(default_varnames, self.__fn__.__defaults__))
 68            )
 69        if self.__fn__.__kwdefaults__ is not None:
 70            # Update the output dictionary with the keyword default values
 71            self.__fn_defaults__.update(self.__fn__.__kwdefaults__)
 72
 73    def __get_checkable_types__(self):
 74        """
 75        Creates two class attributes:
 76
 77        - `self.__checkable_types__`:
 78            - What: A dictionary of all annotations as checkable types
 79            - Type: dict
 80
 81        - `self.__return_type__`:
 82            - What: The return type of the function or method
 83            - Type: dict | None
 84        """
 85        if not hasattr(self, "__checkable_types__"):
 86            self.__checkable_types__ = {
 87                key: self.__get_checkable_type__(value)
 88                for key, value in get_type_hints(self.__fn__).items()
 89            }
 90            self.__return_type__ = self.__checkable_types__.pop("return", None)
 91
 92    def __get_checkable_type__(self, annotation):
 93        """
 94        Parses a type annotation and returns a nested dict structure
 95        representing the checkable type(s) for validation.
 96        """
 97
 98        if annotation is None:
 99            return {type(None): None}
100
101        # Handle `int | str` syntax (Python 3.10+) and Unions
102        if (
103            isinstance(annotation, UnionType)
104            or getattr(annotation, "__origin__", None) == Union
105        ):
106            combined_types = {}
107            for sub_type in annotation.__args__:
108                combined_types = DeepMerge(
109                    combined_types, self.__get_checkable_type__(sub_type)
110                )
111            return combined_types
112
113        # Handle typing.Literal
114        if getattr(annotation, "__origin__", None) == Literal:
115            return {"__extra__": {"__literal__": list(annotation.__args__)}}
116
117        # Handle generic collections
118        origin = getattr(annotation, "__origin__", None)
119        args = getattr(annotation, "__args__", ())
120
121        if origin == list:
122            if len(args) != 1:
123                self.__exception__(
124                    f"List must have a single type argument, got: {args}",
125                    raise_exception=True,
126                )
127            return {list: self.__get_checkable_type__(args[0])}
128
129        if origin == dict:
130            if len(args) != 2:
131                self.__exception__(
132                    f"Dict must have two type arguments, got: {args}",
133                    raise_exception=True,
134                )
135            key_type = self.__get_checkable_type__(args[0])
136            value_type = self.__get_checkable_type__(args[1])
137            return {dict: (key_type, value_type)}
138
139        if origin == tuple:
140            if len(args) > 2 or len(args) == 1:
141                if Ellipsis in args:
142                    self.__exception__(
143                        "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.",
144                        raise_exception=True,
145                    )
146            if len(args) == 2:
147                if args[0] is Ellipsis:
148                    self.__exception__(
149                        "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.",
150                        raise_exception=True,
151                    )
152                if args[1] is Ellipsis:
153                    return {tuple: (self.__get_checkable_type__(args[0]), True)}
154            return {
155                tuple: (
156                    tuple(self.__get_checkable_type__(arg) for arg in args),
157                    False,
158                )
159            }
160
161        if origin == set:
162            if len(args) != 1:
163                self.__exception__(
164                    f"Set must have a single type argument, got: {args}",
165                    raise_exception=True,
166                )
167            return {set: self.__get_checkable_type__(args[0])}
168
169        # Handle Sized types
170        if annotation == Sized:
171            return {
172                list: None,
173                tuple: None,
174                dict: None,
175                set: None,
176                str: None,
177                bytes: None,
178                bytearray: None,
179                memoryview: None,
180                range: None,
181            }
182
183        # Handle Callable types
184        if annotation == Callable:
185            return {
186                staticmethod: None,
187                classmethod: None,
188                FunctionType: None,
189                BuiltinFunctionType: None,
190                MethodType: None,
191                BuiltinMethodType: None,
192                GeneratorType: None,
193            }
194
195        if annotation == Any:
196            return {
197                object: None,
198            }
199
200        # Handle Constraints
201        if isinstance(annotation, GenericConstraint):
202            return {"__extra__": {"__constraints__": [annotation]}}
203
204        # Handle standard types
205        if isinstance(annotation, type):
206            return {annotation: None}
207
208        # Hanldle typing.Type (for uninitialized classes)
209        if origin is type and len(args) == 1:
210            return {annotation: None}
211
212        self.__exception__(
213            f"Unsupported type hint: {annotation}", raise_exception=True
214        )
215
216    def __exception__(self, message, raise_exception=False):
217        """
218        Usage:
219
220        - Creates a class based exception message
221
222        Requires:
223
224        - `message`:
225            - Type: str
226            - What: The message to warn users with
227
228        Optional:
229
230        - `raise_exception`:
231            - Type: bool
232            - What: Forces an exception to be raised regardless of the `self.__strict__` setting.
233            - Default: False
234        """
235        if self.__strict__ or raise_exception:
236            msg = f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}"
237            if self.__clean_traceback__:
238                package_path = Path(__file__).parent.resolve()
239                frame = sys._getframe()
240                relevant_tb_count = 0
241                while frame is not None:
242                    frame_file = Path(frame.f_code.co_filename).resolve()
243                    try:
244                        frame_file.relative_to(package_path)
245                    except ValueError:
246                        relevant_tb_count += 1
247                    frame = frame.f_back
248                original_excepthook = sys.excepthook
249
250                def excepthook(type, value, tb):
251                    traceback.print_exception(
252                        type, value, tb, limit=relevant_tb_count
253                    )
254                    sys.excepthook = original_excepthook
255
256                sys.excepthook = excepthook
257            raise TypeError(msg)
258        else:
259            print(
260                f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}"
261            )
262
263    def __get__(self, obj, objtype):
264        """
265        Overwrite standard __get__ method to return __call__ instead for wrapped class methods.
266
267        Also stores the calling (__get__) `obj` to be passed as an initial argument for `__call__` such that methods can pass `self` correctly.
268        """
269
270        @wraps(self.__fn__)
271        def __get_fn__(*args, **kwargs):
272            return self.__call__(*args, **kwargs)
273
274        self.__outer_self__ = obj
275        return __get_fn__
276
277    def __check_method_function__(self):
278        """
279        Validate that `self.__fn__` is a method or function
280        """
281        if not isinstance(self.__fn__, (MethodType, FunctionType)):
282            raise Exception(
283                f"A non function/method was passed to Enforcer. See the stack trace above for more information."
284            )
285
286    def __call__(self, *args, **kwargs):
287        """
288        This method is used to validate the passed inputs and return the output of the wrapped function or method.
289        """
290        # Special code to pass self as an initial argument
291        # for validation purposes in methods
292        # See: self.__get__
293        if self.__outer_self__ is not None:
294            args = (self.__outer_self__, *args)
295        # Get a dictionary of all annotations as checkable types
296        # Note: This is only done once at first call to avoid redundant calculations
297        self.__get_checkable_types__()
298        # Create a compreshensive dictionary of assigned variables (order matters)
299        assigned_vars = {
300            **self.__fn_defaults__,
301            **dict(zip(self.__fn__.__code__.co_varnames[: len(args)], args)),
302            **kwargs,
303        }
304        # Validate all listed annotations vs the assigned_vars dictionary
305        for key, value in self.__checkable_types__.items():
306            self.__check_type__(assigned_vars.get(key), value, key)
307        # Execute the function callable
308        return_value = self.__fn__(*args, **kwargs)
309        # If a return type was passed, validate the returned object
310        if self.__return_type__ is not None:
311            self.__check_type__(return_value, self.__return_type__, "return")
312        return return_value
313
314    def __quick_check__(self, subtype, obj):
315        if all([v == None for v in subtype.values()]):
316            # If the subtype does not contain iterables with typing, we can validate the items directly.
317            types = set(subtype.keys())
318            values = set([type(v) for v in obj])
319            if values.issubset(types):
320                # We can return True to bypass the full validation
321                return True
322            # Otherwise, validation did not pass and a full validation is required to raise an indexed/keyed type mismatch error
323        return False
324
325    def __check_type__(self, obj, expected, key):
326        """
327        Raises an exception the type of a passed `obj` (parameter) is not in the list of supplied `acceptable_types` for the argument.
328        """
329        # Special case for None
330        if obj is None and type(None) in expected:
331            return
332        extra = expected.get("__extra__", {})
333        expected = {k: v for k, v in expected.items() if k != "__extra__"}
334
335        if isinstance(obj, type):
336            # An uninitialized class is passed, we need to check if the type is in the expected types using Type[obj]
337            obj_type = Type[obj]
338            is_present = obj_type in expected
339        else:
340            obj_type = type(obj)
341            is_present = isinstance(obj, tuple(expected.keys()))
342
343        if not is_present:
344            # Allow for literals to be used to bypass type checks if present
345            literal = extra.get("__literal__", ())
346            if literal:
347                if obj not in literal:
348                    self.__exception__(
349                        f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead."
350                    )
351            # Raise an exception if the type is not in the expected types
352            else:
353                self.__exception__(
354                    f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead."
355                )
356        # If the object_type is in the expected types, we can proceed with validation
357        elif obj_type in iterable_types:
358            subtype = expected.get(obj_type, None)
359            if subtype is None:
360                pass
361            # Recursive validation
362            elif obj_type == list:
363                # If the subtype does not contain iterables with typing, we can validate the items directly.
364                if not self.__quick_check__(subtype, obj):
365                    for idx, item in enumerate(obj):
366                        self.__check_type__(item, subtype, f"{key}[{idx}]")
367            elif obj_type == dict:
368                key_type, val_type = subtype
369                if not self.__quick_check__(key_type, obj.keys()):
370                    for key in obj.keys():
371                        self.__check_type__(
372                            key, key_type, f"{key}.key[{repr(key)}]"
373                        )
374                if not self.__quick_check__(val_type, obj.values()):
375                    for key, value in obj.items():
376                        self.__check_type__(
377                            value, val_type, f"{key}[{repr(key)}]"
378                        )
379            elif obj_type == tuple:
380                expected_args, is_ellipsis = subtype
381                if is_ellipsis:
382                    if not self.__quick_check__(expected_args, obj):
383                        for idx, item in enumerate(obj):
384                            self.__check_type__(
385                                item, expected_args, f"{key}[{idx}]"
386                            )
387                else:
388                    if len(obj) != len(expected_args):
389                        self.__exception__(
390                            f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}"
391                        )
392                    for idx, (item, ex) in enumerate(zip(obj, expected_args)):
393                        self.__check_type__(item, ex, f"{key}[{idx}]")
394            elif obj_type == set:
395                if not self.__quick_check__(subtype, obj):
396                    for item in obj:
397                        self.__check_type__(
398                            item, subtype, f"{key}[{repr(item)}]"
399                        )
400
401        # Validate constraints if any are present
402        constraints = extra.get("__constraints__", [])
403        for constraint in constraints:
404            constraint_validation_output = constraint.__validate__(key, obj)
405            if constraint_validation_output is not True:
406                self.__exception__(
407                    f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}"
408                )
409
410    def __repr__(self):
411        return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"
412
413
@Partial
def Enforcer(clsFnMethod, enabled=True, strict=True, clean_traceback=True):
    """
    A wrapper to enforce types within a function or method given argument annotations.

    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible.

    If wrapping a class, all entries of the class `__dict__` that meet any of the following criteria will be wrapped individually:

    - Callable attributes (anything with `__call__`, e.g. functions and nested classes)
    - Methods wrapped with `staticmethod` (if python >= 3.10)
    - Methods wrapped with `classmethod` (if python >= 3.10)

    Requires:

    - `clsFnMethod`:
        - What: The class, function or method that should have input types enforced
        - Type: function | method | class

    Optional:

    - `enabled`:
        - What: A boolean to enable or disable the enforcer. If False, `clsFnMethod` is returned unchanged.
        - Type: bool
        - Default: True
    - `strict`:
        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
        - Type: bool
        - Default: True
        - Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
    - `clean_traceback`:
        - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
        - If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
        - Type: bool
        - Default: True

    Raises:

    - `Exception`: If `clsFnMethod` is not a class, method, or function. (Objects that cannot store
      the `__type_enforced_enabled__` flag are returned unchanged instead.)

    Example Use:
    ```
    >>> import type_enforced
    >>> @type_enforced.Enforcer
    ... def my_fn(a: int , b: int | str =2, c: int =3) -> None:
    ...     pass
    ...
    >>> my_fn(a=1, b=2, c=3)
    >>> my_fn(a=1, b='2', c=3)
    >>> my_fn(a='a', b=2, c=3)
    Traceback (most recent call last):
      ...
    TypeError: TypeEnforced Exception (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.
    ```
    """
    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
        # Objects that cannot store the flag (e.g. immutable builtins) are returned untouched.
        # `except Exception` keeps this best-effort without swallowing KeyboardInterrupt/SystemExit.
        try:
            clsFnMethod.__type_enforced_enabled__ = enabled
        except Exception:
            return clsFnMethod
    if not clsFnMethod.__type_enforced_enabled__:
        return clsFnMethod
    if isinstance(
        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
    ):
        # Only apply the enforcer if type hints are present. Resolving hints can fail at
        # decoration time (e.g. unresolved forward refs); then wrap anyway, since the
        # enforcer resolves hints lazily on its first call.
        try:
            if get_type_hints(clsFnMethod) == {}:
                return clsFnMethod
        except Exception:
            pass
        if isinstance(clsFnMethod, staticmethod):
            return staticmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__,
                    __strict__=strict,
                    __clean_traceback__=clean_traceback,
                )
            )
        elif isinstance(clsFnMethod, classmethod):
            return classmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__,
                    __strict__=strict,
                    __clean_traceback__=clean_traceback,
                )
            )
        else:
            return FunctionMethodEnforcer(
                __fn__=clsFnMethod,
                __strict__=strict,
                __clean_traceback__=clean_traceback,
            )
    elif hasattr(clsFnMethod, "__dict__"):
        # Snapshot the namespace: the loop rebinds entries of the class while walking it.
        for key, value in list(clsFnMethod.__dict__.items()):
            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
            if key == "__annotate__" or isinstance(
                value, FunctionMethodEnforcer
            ):
                continue
            if hasattr(value, "__call__") or isinstance(
                value, (classmethod, staticmethod)
            ):
                setattr(
                    clsFnMethod,
                    key,
                    Enforcer(
                        value,
                        enabled=enabled,
                        strict=strict,
                        clean_traceback=clean_traceback,
                    ),
                )
        return clsFnMethod
    else:
        raise Exception(
            "Enforcer can only be used on classes, methods, or functions."
        )
class FunctionMethodEnforcer:
 22class FunctionMethodEnforcer:
 23    def __init__(self, __fn__, __strict__=False, __clean_traceback__=True):
 24        """
 25        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.
 26
 27        Requires:
 28
 29            - `__fn__`:
 30                - What: The function to enforce
 31                - Type: function | method | class
 32
 33        Optional:
 34
 35            - `__strict__`:
 36                - What: A boolean to enable or disable exceptions. If True, exceptions will be raised
 37                    when type checking fails. If False, exceptions will not be raised but instead a warning
 38                    will be printed to the console.
 39                - Type: bool
 40                - Default: False
 41            - `__clean_traceback__`:
 42                - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
 43                - Type: bool
 44                - Default: True
 45        """
 46        update_wrapper(self, __fn__)
 47        self.__fn__ = __fn__
 48        self.__strict__ = __strict__
 49        self.__clean_traceback__ = __clean_traceback__
 50        self.__outer_self__ = None
 51        # Validate that the passed function or method is a method or function
 52        self.__check_method_function__()
 53        # Get input defaults for the function or method
 54        self.__get_defaults__()
 55
 56    def __get_defaults__(self):
 57        """
 58        Get the default values of the passed function or method and store them in `self.__fn_defaults__`.
 59        """
 60        self.__fn_defaults__ = {}
 61        if self.__fn__.__defaults__ is not None:
 62            # Get the names of all provided default values for args
 63            default_varnames = list(self.__fn__.__code__.co_varnames)[
 64                : self.__fn__.__code__.co_argcount
 65            ][-len(self.__fn__.__defaults__) :]
 66            # Update the output dictionary with the default values
 67            self.__fn_defaults__.update(
 68                dict(zip(default_varnames, self.__fn__.__defaults__))
 69            )
 70        if self.__fn__.__kwdefaults__ is not None:
 71            # Update the output dictionary with the keyword default values
 72            self.__fn_defaults__.update(self.__fn__.__kwdefaults__)
 73
 74    def __get_checkable_types__(self):
 75        """
 76        Creates two class attributes:
 77
 78        - `self.__checkable_types__`:
 79            - What: A dictionary of all annotations as checkable types
 80            - Type: dict
 81
 82        - `self.__return_type__`:
 83            - What: The return type of the function or method
 84            - Type: dict | None
 85        """
 86        if not hasattr(self, "__checkable_types__"):
 87            self.__checkable_types__ = {
 88                key: self.__get_checkable_type__(value)
 89                for key, value in get_type_hints(self.__fn__).items()
 90            }
 91            self.__return_type__ = self.__checkable_types__.pop("return", None)
 92
 93    def __get_checkable_type__(self, annotation):
 94        """
 95        Parses a type annotation and returns a nested dict structure
 96        representing the checkable type(s) for validation.
 97        """
 98
 99        if annotation is None:
100            return {type(None): None}
101
102        # Handle `int | str` syntax (Python 3.10+) and Unions
103        if (
104            isinstance(annotation, UnionType)
105            or getattr(annotation, "__origin__", None) == Union
106        ):
107            combined_types = {}
108            for sub_type in annotation.__args__:
109                combined_types = DeepMerge(
110                    combined_types, self.__get_checkable_type__(sub_type)
111                )
112            return combined_types
113
114        # Handle typing.Literal
115        if getattr(annotation, "__origin__", None) == Literal:
116            return {"__extra__": {"__literal__": list(annotation.__args__)}}
117
118        # Handle generic collections
119        origin = getattr(annotation, "__origin__", None)
120        args = getattr(annotation, "__args__", ())
121
122        if origin == list:
123            if len(args) != 1:
124                self.__exception__(
125                    f"List must have a single type argument, got: {args}",
126                    raise_exception=True,
127                )
128            return {list: self.__get_checkable_type__(args[0])}
129
130        if origin == dict:
131            if len(args) != 2:
132                self.__exception__(
133                    f"Dict must have two type arguments, got: {args}",
134                    raise_exception=True,
135                )
136            key_type = self.__get_checkable_type__(args[0])
137            value_type = self.__get_checkable_type__(args[1])
138            return {dict: (key_type, value_type)}
139
140        if origin == tuple:
141            if len(args) > 2 or len(args) == 1:
142                if Ellipsis in args:
143                    self.__exception__(
144                        "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.",
145                        raise_exception=True,
146                    )
147            if len(args) == 2:
148                if args[0] is Ellipsis:
149                    self.__exception__(
150                        "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.",
151                        raise_exception=True,
152                    )
153                if args[1] is Ellipsis:
154                    return {tuple: (self.__get_checkable_type__(args[0]), True)}
155            return {
156                tuple: (
157                    tuple(self.__get_checkable_type__(arg) for arg in args),
158                    False,
159                )
160            }
161
162        if origin == set:
163            if len(args) != 1:
164                self.__exception__(
165                    f"Set must have a single type argument, got: {args}",
166                    raise_exception=True,
167                )
168            return {set: self.__get_checkable_type__(args[0])}
169
170        # Handle Sized types
171        if annotation == Sized:
172            return {
173                list: None,
174                tuple: None,
175                dict: None,
176                set: None,
177                str: None,
178                bytes: None,
179                bytearray: None,
180                memoryview: None,
181                range: None,
182            }
183
184        # Handle Callable types
185        if annotation == Callable:
186            return {
187                staticmethod: None,
188                classmethod: None,
189                FunctionType: None,
190                BuiltinFunctionType: None,
191                MethodType: None,
192                BuiltinMethodType: None,
193                GeneratorType: None,
194            }
195
196        if annotation == Any:
197            return {
198                object: None,
199            }
200
201        # Handle Constraints
202        if isinstance(annotation, GenericConstraint):
203            return {"__extra__": {"__constraints__": [annotation]}}
204
205        # Handle standard types
206        if isinstance(annotation, type):
207            return {annotation: None}
208
209        # Hanldle typing.Type (for uninitialized classes)
210        if origin is type and len(args) == 1:
211            return {annotation: None}
212
213        self.__exception__(
214            f"Unsupported type hint: {annotation}", raise_exception=True
215        )
216
217    def __exception__(self, message, raise_exception=False):
218        """
219        Usage:
220
221        - Creates a class based exception message
222
223        Requires:
224
225        - `message`:
226            - Type: str
227            - What: The message to warn users with
228
229        Optional:
230
231        - `raise_exception`:
232            - Type: bool
233            - What: Forces an exception to be raised regardless of the `self.__strict__` setting.
234            - Default: False
235        """
236        if self.__strict__ or raise_exception:
237            msg = f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}"
238            if self.__clean_traceback__:
239                package_path = Path(__file__).parent.resolve()
240                frame = sys._getframe()
241                relevant_tb_count = 0
242                while frame is not None:
243                    frame_file = Path(frame.f_code.co_filename).resolve()
244                    try:
245                        frame_file.relative_to(package_path)
246                    except ValueError:
247                        relevant_tb_count += 1
248                    frame = frame.f_back
249                original_excepthook = sys.excepthook
250
251                def excepthook(type, value, tb):
252                    traceback.print_exception(
253                        type, value, tb, limit=relevant_tb_count
254                    )
255                    sys.excepthook = original_excepthook
256
257                sys.excepthook = excepthook
258            raise TypeError(msg)
259        else:
260            print(
261                f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}"
262            )
263
264    def __get__(self, obj, objtype):
265        """
266        Overwrite standard __get__ method to return __call__ instead for wrapped class methods.
267
268        Also stores the calling (__get__) `obj` to be passed as an initial argument for `__call__` such that methods can pass `self` correctly.
269        """
270
271        @wraps(self.__fn__)
272        def __get_fn__(*args, **kwargs):
273            return self.__call__(*args, **kwargs)
274
275        self.__outer_self__ = obj
276        return __get_fn__
277
278    def __check_method_function__(self):
279        """
280        Validate that `self.__fn__` is a method or function
281        """
282        if not isinstance(self.__fn__, (MethodType, FunctionType)):
283            raise Exception(
284                f"A non function/method was passed to Enforcer. See the stack trace above for more information."
285            )
286
287    def __call__(self, *args, **kwargs):
288        """
289        This method is used to validate the passed inputs and return the output of the wrapped function or method.
290        """
291        # Special code to pass self as an initial argument
292        # for validation purposes in methods
293        # See: self.__get__
294        if self.__outer_self__ is not None:
295            args = (self.__outer_self__, *args)
296        # Get a dictionary of all annotations as checkable types
297        # Note: This is only done once at first call to avoid redundant calculations
298        self.__get_checkable_types__()
299        # Create a compreshensive dictionary of assigned variables (order matters)
300        assigned_vars = {
301            **self.__fn_defaults__,
302            **dict(zip(self.__fn__.__code__.co_varnames[: len(args)], args)),
303            **kwargs,
304        }
305        # Validate all listed annotations vs the assigned_vars dictionary
306        for key, value in self.__checkable_types__.items():
307            self.__check_type__(assigned_vars.get(key), value, key)
308        # Execute the function callable
309        return_value = self.__fn__(*args, **kwargs)
310        # If a return type was passed, validate the returned object
311        if self.__return_type__ is not None:
312            self.__check_type__(return_value, self.__return_type__, "return")
313        return return_value
314
315    def __quick_check__(self, subtype, obj):
316        if all([v == None for v in subtype.values()]):
317            # If the subtype does not contain iterables with typing, we can validate the items directly.
318            types = set(subtype.keys())
319            values = set([type(v) for v in obj])
320            if values.issubset(types):
321                # We can return True to bypass the full validation
322                return True
323            # Otherwise, validation did not pass and a full validation is required to raise an indexed/keyed type mismatch error
324        return False
325
326    def __check_type__(self, obj, expected, key):
327        """
328        Raises an exception the type of a passed `obj` (parameter) is not in the list of supplied `acceptable_types` for the argument.
329        """
330        # Special case for None
331        if obj is None and type(None) in expected:
332            return
333        extra = expected.get("__extra__", {})
334        expected = {k: v for k, v in expected.items() if k != "__extra__"}
335
336        if isinstance(obj, type):
337            # An uninitialized class is passed, we need to check if the type is in the expected types using Type[obj]
338            obj_type = Type[obj]
339            is_present = obj_type in expected
340        else:
341            obj_type = type(obj)
342            is_present = isinstance(obj, tuple(expected.keys()))
343
344        if not is_present:
345            # Allow for literals to be used to bypass type checks if present
346            literal = extra.get("__literal__", ())
347            if literal:
348                if obj not in literal:
349                    self.__exception__(
350                        f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead."
351                    )
352            # Raise an exception if the type is not in the expected types
353            else:
354                self.__exception__(
355                    f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead."
356                )
357        # If the object_type is in the expected types, we can proceed with validation
358        elif obj_type in iterable_types:
359            subtype = expected.get(obj_type, None)
360            if subtype is None:
361                pass
362            # Recursive validation
363            elif obj_type == list:
364                # If the subtype does not contain iterables with typing, we can validate the items directly.
365                if not self.__quick_check__(subtype, obj):
366                    for idx, item in enumerate(obj):
367                        self.__check_type__(item, subtype, f"{key}[{idx}]")
368            elif obj_type == dict:
369                key_type, val_type = subtype
370                if not self.__quick_check__(key_type, obj.keys()):
371                    for key in obj.keys():
372                        self.__check_type__(
373                            key, key_type, f"{key}.key[{repr(key)}]"
374                        )
375                if not self.__quick_check__(val_type, obj.values()):
376                    for key, value in obj.items():
377                        self.__check_type__(
378                            value, val_type, f"{key}[{repr(key)}]"
379                        )
380            elif obj_type == tuple:
381                expected_args, is_ellipsis = subtype
382                if is_ellipsis:
383                    if not self.__quick_check__(expected_args, obj):
384                        for idx, item in enumerate(obj):
385                            self.__check_type__(
386                                item, expected_args, f"{key}[{idx}]"
387                            )
388                else:
389                    if len(obj) != len(expected_args):
390                        self.__exception__(
391                            f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}"
392                        )
393                    for idx, (item, ex) in enumerate(zip(obj, expected_args)):
394                        self.__check_type__(item, ex, f"{key}[{idx}]")
395            elif obj_type == set:
396                if not self.__quick_check__(subtype, obj):
397                    for item in obj:
398                        self.__check_type__(
399                            item, subtype, f"{key}[{repr(item)}]"
400                        )
401
402        # Validate constraints if any are present
403        constraints = extra.get("__constraints__", [])
404        for constraint in constraints:
405            constraint_validation_output = constraint.__validate__(key, obj)
406            if constraint_validation_output is not True:
407                self.__exception__(
408                    f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}"
409                )
410
411    def __repr__(self):
412        return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"
FunctionMethodEnforcer(__fn__, __strict__=False, __clean_traceback__=True)
23    def __init__(self, __fn__, __strict__=False, __clean_traceback__=True):
24        """
25        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.
26
27        Requires:
28
29            - `__fn__`:
30                - What: The function to enforce
31                - Type: function | method | class
32
33        Optional:
34
35            - `__strict__`:
36                - What: A boolean to enable or disable exceptions. If True, exceptions will be raised
37                    when type checking fails. If False, exceptions will not be raised but instead a warning
38                    will be printed to the console.
39                - Type: bool
40                - Default: False
41            - `__clean_traceback__`:
42                - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
43                - Type: bool
44                - Default: True
45        """
46        update_wrapper(self, __fn__)
47        self.__fn__ = __fn__
48        self.__strict__ = __strict__
49        self.__clean_traceback__ = __clean_traceback__
50        self.__outer_self__ = None
51        # Validate that the passed function or method is a method or function
52        self.__check_method_function__()
53        # Get input defaults for the function or method
54        self.__get_defaults__()

Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function __fn__.

Requires:

- `__fn__`:
    - What: The function to enforce
    - Type: function | method

Optional:

- `__strict__`:
    - What: A boolean to enable or disable exceptions. If True, exceptions will be raised
        when type checking fails. If False, exceptions will not be raised but instead a warning
        will be printed to the console.
    - Type: bool
    - Default: False
- `__clean_traceback__`:
    - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
    - Type: bool
    - Default: True
@Partial
def Enforcer(clsFnMethod, enabled=True, strict=True, clean_traceback=True):
@Partial
def Enforcer(clsFnMethod, enabled=True, strict=True, clean_traceback=True):
    """
    A wrapper to enforce types within a function or method given argument annotations.

    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible.

    If wrapping a class, every attribute in the class `__dict__` that meets any of the following criteria will be wrapped individually:

    - Callable attributes (anything defining `__call__`, which includes plain methods)
    - Methods wrapped with `staticmethod` (if python >= 3.10)
    - Methods wrapped with `classmethod` (if python >= 3.10)

    Requires:

    - `clsFnMethod`:
        - What: The class, function or method that should have input types enforced
        - Type: function | method | class

    Optional:

    - `enabled`:
        - What: A boolean to enable or disable the enforcer
        - Type: bool
        - Default: True
    - `strict`:
        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
        - Type: bool
        - Default: True
        - Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
    - `clean_traceback`:
        - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
        - If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
        - Type: bool
        - Default: True

    Raises:

    - `Exception` if `clsFnMethod` is neither a function/method/staticmethod/classmethod nor an object with a `__dict__`.

    Example Use:
    ```
    >>> import type_enforced
    >>> @type_enforced.Enforcer
    ... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
    ...     pass
    ...
    >>> my_fn(a=1, b=2, c=3)
    >>> my_fn(a=1, b='2', c=3)
    >>> my_fn(a='a', b=2, c=3)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    TypeError: TypeEnforced Exception (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.
    ```
    """
    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
        # Record the enabled flag on the object itself. Objects that reject new
        # attributes (e.g. builtins or immutable types) are returned unwrapped.
        try:
            clsFnMethod.__type_enforced_enabled__ = enabled
        except Exception:
            return clsFnMethod
    if not clsFnMethod.__type_enforced_enabled__:
        return clsFnMethod
    if isinstance(
        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
    ):
        # Only apply the enforcer if type hints are present.
        # Hints that cannot be resolved here (e.g. forward refs) fall through to wrapping anyway.
        try:
            if get_type_hints(clsFnMethod) == {}:
                return clsFnMethod
        except Exception:
            pass
        if isinstance(clsFnMethod, staticmethod):
            return staticmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__,
                    __strict__=strict,
                    __clean_traceback__=clean_traceback,
                )
            )
        elif isinstance(clsFnMethod, classmethod):
            return classmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__,
                    __strict__=strict,
                    __clean_traceback__=clean_traceback,
                )
            )
        else:
            return FunctionMethodEnforcer(
                __fn__=clsFnMethod,
                __strict__=strict,
                __clean_traceback__=clean_traceback,
            )
    elif hasattr(clsFnMethod, "__dict__"):
        # Iterate over a snapshot: the loop rebinds attributes on the very object
        # whose __dict__ is being read.
        for key, value in list(clsFnMethod.__dict__.items()):
            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
            if key == "__annotate__" or isinstance(
                value, FunctionMethodEnforcer
            ):
                continue
            if hasattr(value, "__call__") or isinstance(
                value, (classmethod, staticmethod)
            ):
                setattr(
                    clsFnMethod,
                    key,
                    Enforcer(
                        value,
                        enabled=enabled,
                        strict=strict,
                        clean_traceback=clean_traceback,
                    ),
                )
        return clsFnMethod
    else:
        raise Exception(
            "Enforcer can only be used on classes, methods, or functions."
        )

A wrapper to enforce types within a function or method given argument annotations.

Each wrapped item is converted into a special FunctionMethodEnforcer class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a FunctionMethodEnforcer class as no validation is possible.

If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:

  • Callable attributes (anything defining __call__, which includes plain methods)
  • Methods wrapped with staticmethod (if python >= 3.10)
  • Methods wrapped with classmethod (if python >= 3.10)

Requires:

  • clsFnMethod:
    • What: The class, function or method that should have input types enforced
    • Type: function | method | class

Optional:

  • enabled:
    • What: A boolean to enable or disable the enforcer
    • Type: bool
    • Default: True
  • strict:
    • What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
    • Type: bool
    • Default: True
    • Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
  • clean_traceback:
    • What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
    • If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
    • Type: bool
    • Default: True

Example Use:

>>> import type_enforced
>>> @type_enforced.Enforcer
... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
...     pass
...
>>> my_fn(a=1, b=2, c=3)
>>> my_fn(a=1, b='2', c=3)
>>> my_fn(a='a', b=2, c=3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: TypeEnforced Exception (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.