type_enforced.enforcer

  1from types import (
  2    FunctionType,
  3    MethodType,
  4    GeneratorType,
  5    BuiltinFunctionType,
  6    BuiltinMethodType,
  7    UnionType,
  8)
  9from typing import Type, Union, Sized, Literal, Callable, get_type_hints, Any
 10from functools import update_wrapper, wraps
 11from type_enforced.utils import (
 12    Partial,
 13    GenericConstraint,
 14    DeepMerge,
 15    iterable_types,
 16)
 17
 18
 19class FunctionMethodEnforcer:
 20    def __init__(self, __fn__, __strict__=False):
 21        """
 22        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.
 23
 24        Requires:
 25
 26            - `__fn__`:
 27                - What: The function to enforce
 28                - Type: function | method | class
 29        """
 30        update_wrapper(self, __fn__)
 31        self.__fn__ = __fn__
 32        self.__strict__ = __strict__
 33        self.__outer_self__ = None
 34        # Validate that the passed function or method is a method or function
 35        self.__check_method_function__()
 36        # Get input defaults for the function or method
 37        self.__get_defaults__()
 38
 39    def __get_defaults__(self):
 40        """
 41        Get the default values of the passed function or method and store them in `self.__fn_defaults__`.
 42        """
 43        self.__fn_defaults__ = {}
 44        if self.__fn__.__defaults__ is not None:
 45            # Get the names of all provided default values for args
 46            default_varnames = list(self.__fn__.__code__.co_varnames)[
 47                : self.__fn__.__code__.co_argcount
 48            ][-len(self.__fn__.__defaults__) :]
 49            # Update the output dictionary with the default values
 50            self.__fn_defaults__.update(
 51                dict(zip(default_varnames, self.__fn__.__defaults__))
 52            )
 53        if self.__fn__.__kwdefaults__ is not None:
 54            # Update the output dictionary with the keyword default values
 55            self.__fn_defaults__.update(self.__fn__.__kwdefaults__)
 56
 57    def __get_checkable_types__(self):
 58        """
 59        Creates two class attributes:
 60
 61        - `self.__checkable_types__`:
 62            - What: A dictionary of all annotations as checkable types
 63            - Type: dict
 64
 65        - `self.__return_type__`:
 66            - What: The return type of the function or method
 67            - Type: dict | None
 68        """
 69        if not hasattr(self, "__checkable_types__"):
 70            self.__checkable_types__ = {
 71                key: self.__get_checkable_type__(value)
 72                for key, value in get_type_hints(self.__fn__).items()
 73            }
 74            self.__return_type__ = self.__checkable_types__.pop("return", None)
 75
 76    def __get_checkable_type__(self, annotation):
 77        """
 78        Parses a type annotation and returns a nested dict structure
 79        representing the checkable type(s) for validation.
 80        """
 81
 82        if annotation is None:
 83            return {type(None): None}
 84
 85        # Handle `int | str` syntax (Python 3.10+) and Unions
 86        if (
 87            isinstance(annotation, UnionType)
 88            or getattr(annotation, "__origin__", None) == Union
 89        ):
 90            combined_types = {}
 91            for sub_type in annotation.__args__:
 92                combined_types = DeepMerge(
 93                    combined_types, self.__get_checkable_type__(sub_type)
 94                )
 95            return combined_types
 96
 97        # Handle typing.Literal
 98        if getattr(annotation, "__origin__", None) == Literal:
 99            return {"__extra__": {"__literal__": list(annotation.__args__)}}
100
101        # Handle generic collections
102        origin = getattr(annotation, "__origin__", None)
103        args = getattr(annotation, "__args__", ())
104
105        if origin == list:
106            if len(args) != 1:
107                self.__exception__(
108                    f"List must have a single type argument, got: {args}",
109                    raise_exception=True,
110                )
111            return {list: self.__get_checkable_type__(args[0])}
112
113        if origin == dict:
114            if len(args) != 2:
115                self.__exception__(
116                    f"Dict must have two type arguments, got: {args}",
117                    raise_exception=True,
118                )
119            key_type = self.__get_checkable_type__(args[0])
120            value_type = self.__get_checkable_type__(args[1])
121            return {dict: (key_type, value_type)}
122
123        if origin == tuple:
124            if len(args) > 2 or len(args) == 1:
125                if Ellipsis in args:
126                    self.__exception__(
127                        "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.",
128                        raise_exception=True,
129                    )
130            if len(args) == 2:
131                if args[0] is Ellipsis:
132                    self.__exception__(
133                        "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.",
134                        raise_exception=True,
135                    )
136                if args[1] is Ellipsis:
137                    return {tuple: (self.__get_checkable_type__(args[0]), True)}
138            return {
139                tuple: (
140                    tuple(self.__get_checkable_type__(arg) for arg in args),
141                    False,
142                )
143            }
144
145        if origin == set:
146            if len(args) != 1:
147                self.__exception__(
148                    f"Set must have a single type argument, got: {args}",
149                    raise_exception=True,
150                )
151            return {set: self.__get_checkable_type__(args[0])}
152
153        # Handle Sized types
154        if annotation == Sized:
155            return {
156                list: None,
157                tuple: None,
158                dict: None,
159                set: None,
160                str: None,
161                bytes: None,
162                bytearray: None,
163                memoryview: None,
164                range: None,
165            }
166
167        # Handle Callable types
168        if annotation == Callable:
169            return {
170                staticmethod: None,
171                classmethod: None,
172                FunctionType: None,
173                BuiltinFunctionType: None,
174                MethodType: None,
175                BuiltinMethodType: None,
176                GeneratorType: None,
177            }
178
179        if annotation == Any:
180            return {
181                object: None,
182            }
183
184        # Handle Constraints
185        if isinstance(annotation, GenericConstraint):
186            return {"__extra__": {"__constraints__": [annotation]}}
187
188        # Handle standard types
189        if isinstance(annotation, type):
190            return {annotation: None}
191
192        # Hanldle typing.Type (for uninitialized classes)
193        if origin is type and len(args) == 1:
194            return {annotation: None}
195
196        self.__exception__(
197            f"Unsupported type hint: {annotation}", raise_exception=True
198        )
199
200    def __exception__(self, message, raise_exception=False):
201        """
202        Usage:
203
204        - Creates a class based exception message
205
206        Requires:
207
208        - `message`:
209            - Type: str
210            - What: The message to warn users with
211
212        Optional:
213
214        - `raise_exception`:
215            - Type: bool
216            - What: Forces an exception to be raised regardless of the `self.__strict__` setting.
217            - Default: False
218        """
219        if self.__strict__ or raise_exception:
220            raise TypeError(
221                f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}"
222            )
223        else:
224            print(
225                f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}"
226            )
227
228    def __get__(self, obj, objtype):
229        """
230        Overwrite standard __get__ method to return __call__ instead for wrapped class methods.
231
232        Also stores the calling (__get__) `obj` to be passed as an initial argument for `__call__` such that methods can pass `self` correctly.
233        """
234
235        @wraps(self.__fn__)
236        def __get_fn__(*args, **kwargs):
237            return self.__call__(*args, **kwargs)
238
239        self.__outer_self__ = obj
240        return __get_fn__
241
242    def __check_method_function__(self):
243        """
244        Validate that `self.__fn__` is a method or function
245        """
246        if not isinstance(self.__fn__, (MethodType, FunctionType)):
247            raise Exception(
248                f"A non function/method was passed to Enforcer. See the stack trace above for more information."
249            )
250
251    def __call__(self, *args, **kwargs):
252        """
253        This method is used to validate the passed inputs and return the output of the wrapped function or method.
254        """
255        # Special code to pass self as an initial argument
256        # for validation purposes in methods
257        # See: self.__get__
258        if self.__outer_self__ is not None:
259            args = (self.__outer_self__, *args)
260        # Get a dictionary of all annotations as checkable types
261        # Note: This is only done once at first call to avoid redundant calculations
262        self.__get_checkable_types__()
263        # Create a compreshensive dictionary of assigned variables (order matters)
264        assigned_vars = {
265            **self.__fn_defaults__,
266            **dict(zip(self.__fn__.__code__.co_varnames[: len(args)], args)),
267            **kwargs,
268        }
269        # Validate all listed annotations vs the assigned_vars dictionary
270        for key, value in self.__checkable_types__.items():
271            self.__check_type__(assigned_vars.get(key), value, key)
272        # Execute the function callable
273        return_value = self.__fn__(*args, **kwargs)
274        # If a return type was passed, validate the returned object
275        if self.__return_type__ is not None:
276            self.__check_type__(return_value, self.__return_type__, "return")
277        return return_value
278
279    def __quick_check__(self, subtype, obj):
280        if all([v == None for v in subtype.values()]):
281            # If the subtype does not contain iterables with typing, we can validate the items directly.
282            types = set(subtype.keys())
283            values = set([type(v) for v in obj])
284            if values.issubset(types):
285                # We can return True to bypass the full validation
286                return True
287            # Otherwise, validation did not pass and a full validation is required to raise an indexed/keyed type mismatch error
288        return False
289
290    def __check_type__(self, obj, expected, key):
291        """
292        Raises an exception the type of a passed `obj` (parameter) is not in the list of supplied `acceptable_types` for the argument.
293        """
294        # Special case for None
295        if obj is None and type(None) in expected:
296            return
297        extra = expected.get("__extra__", {})
298        expected = {k: v for k, v in expected.items() if k != "__extra__"}
299
300        if isinstance(obj, type):
301            # An uninitialized class is passed, we need to check if the type is in the expected types using Type[obj]
302            obj_type = Type[obj]
303            is_present = obj_type in expected
304        else:
305            obj_type = type(obj)
306            is_present = isinstance(obj, tuple(expected.keys()))
307
308        if not is_present:
309            # Allow for literals to be used to bypass type checks if present
310            literal = extra.get("__literal__", ())
311            if literal:
312                if obj not in literal:
313                    self.__exception__(
314                        f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead."
315                    )
316            # Raise an exception if the type is not in the expected types
317            else:
318                self.__exception__(
319                    f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead."
320                )
321        # If the object_type is in the expected types, we can proceed with validation
322        elif obj_type in iterable_types:
323            subtype = expected.get(obj_type, None)
324            if subtype is None:
325                pass
326            # Recursive validation
327            elif obj_type == list:
328                # If the subtype does not contain iterables with typing, we can validate the items directly.
329                if not self.__quick_check__(subtype, obj):
330                    for idx, item in enumerate(obj):
331                        self.__check_type__(item, subtype, f"{key}[{idx}]")
332            elif obj_type == dict:
333                key_type, val_type = subtype
334                if not self.__quick_check__(key_type, obj.keys()):
335                    for key in obj.keys():
336                        self.__check_type__(
337                            key, key_type, f"{key}.key[{repr(key)}]"
338                        )
339                if not self.__quick_check__(val_type, obj.values()):
340                    for key, value in obj.items():
341                        self.__check_type__(
342                            value, val_type, f"{key}[{repr(key)}]"
343                        )
344            elif obj_type == tuple:
345                expected_args, is_ellipsis = subtype
346                if is_ellipsis:
347                    if not self.__quick_check__(expected_args, obj):
348                        for idx, item in enumerate(obj):
349                            self.__check_type__(
350                                item, expected_args, f"{key}[{idx}]"
351                            )
352                else:
353                    if len(obj) != len(expected_args):
354                        self.__exception__(
355                            f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}"
356                        )
357                    for idx, (item, ex) in enumerate(zip(obj, expected_args)):
358                        self.__check_type__(item, ex, f"{key}[{idx}]")
359            elif obj_type == set:
360                if not self.__quick_check__(subtype, obj):
361                    for item in obj:
362                        self.__check_type__(
363                            item, subtype, f"{key}[{repr(item)}]"
364                        )
365
366        # Validate constraints if any are present
367        constraints = extra.get("__constraints__", [])
368        for constraint in constraints:
369            constraint_validation_output = constraint.__validate__(key, obj)
370            if constraint_validation_output is not True:
371                self.__exception__(
372                    f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}"
373                )
374
375    def __repr__(self):
376        return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"
377
378
379@Partial
380def Enforcer(clsFnMethod, enabled=True, strict=True):
381    """
382    A wrapper to enforce types within a function or method given argument annotations.
383
384    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible.
385
386    If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:
387
388    - Methods with `__call__`
389    - Methods wrapped with `staticmethod` (if python >= 3.10)
390    - Methods wrapped with `classmethod` (if python >= 3.10)
391
392    Requires:
393
394    - `clsFnMethod`:
395        - What: The class, function or method that should have input types enforced
396        - Type: function | method | class
397
398    Optional:
399
400    - `enabled`:
401        - What: A boolean to enable or disable the enforcer
402        - Type: bool
403        - Default: True
404    - `strict`:
405        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
406        - Type: bool
407        - Default: False
408        - Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
409
410
411    Example Use:
412    ```
413    >>> import type_enforced
414    >>> @type_enforced.Enforcer
415    ... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
416    ...     pass
417    ...
418    >>> my_fn(a=1, b=2, c=3)
419    >>> my_fn(a=1, b='2', c=3)
420    >>> my_fn(a='a', b=2, c=3)
421    Traceback (most recent call last):
422      File "<stdin>", line 1, in <module>
423      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 85, in __call__
424        self.__check_type__(assigned_vars.get(key), value, key)
425      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 107, in __check_type__
426        self.__exception__(
427      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 34, in __exception__
428        raise Exception(f"({self.__fn__.__qualname__}): {message}")
429    Exception: (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` instead.
430    ```
431    """
432    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
433        # Special try except clause to handle cases when the object is immutable
434        try:
435            clsFnMethod.__type_enforced_enabled__ = enabled
436        except:
437            return clsFnMethod
438    if not clsFnMethod.__type_enforced_enabled__:
439        return clsFnMethod
440    if isinstance(
441        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
442    ):
443        # Only apply the enforcer if type_hints are present
444        if get_type_hints(clsFnMethod) == {}:
445            return clsFnMethod
446        elif isinstance(clsFnMethod, staticmethod):
447            return staticmethod(
448                FunctionMethodEnforcer(
449                    __fn__=clsFnMethod.__func__, __strict__=strict
450                )
451            )
452        elif isinstance(clsFnMethod, classmethod):
453            return classmethod(
454                FunctionMethodEnforcer(
455                    __fn__=clsFnMethod.__func__, __strict__=strict
456                )
457            )
458        else:
459            return FunctionMethodEnforcer(__fn__=clsFnMethod, __strict__=strict)
460    elif hasattr(clsFnMethod, "__dict__"):
461        for key, value in clsFnMethod.__dict__.items():
462            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
463            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
464            if key == "__annotate__" or isinstance(
465                value, FunctionMethodEnforcer
466            ):
467                continue
468            if hasattr(value, "__call__") or isinstance(
469                value, (classmethod, staticmethod)
470            ):
471                setattr(
472                    clsFnMethod,
473                    key,
474                    Enforcer(value, enabled=enabled, strict=strict),
475                )
476        return clsFnMethod
477    else:
478        raise Exception(
479            "Enforcer can only be used on classes, methods, or functions."
480        )
class FunctionMethodEnforcer:
 20class FunctionMethodEnforcer:
 21    def __init__(self, __fn__, __strict__=False):
 22        """
 23        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.
 24
 25        Requires:
 26
 27            - `__fn__`:
 28                - What: The function to enforce
 29                - Type: function | method | class
 30        """
 31        update_wrapper(self, __fn__)
 32        self.__fn__ = __fn__
 33        self.__strict__ = __strict__
 34        self.__outer_self__ = None
 35        # Validate that the passed function or method is a method or function
 36        self.__check_method_function__()
 37        # Get input defaults for the function or method
 38        self.__get_defaults__()
 39
 40    def __get_defaults__(self):
 41        """
 42        Get the default values of the passed function or method and store them in `self.__fn_defaults__`.
 43        """
 44        self.__fn_defaults__ = {}
 45        if self.__fn__.__defaults__ is not None:
 46            # Get the names of all provided default values for args
 47            default_varnames = list(self.__fn__.__code__.co_varnames)[
 48                : self.__fn__.__code__.co_argcount
 49            ][-len(self.__fn__.__defaults__) :]
 50            # Update the output dictionary with the default values
 51            self.__fn_defaults__.update(
 52                dict(zip(default_varnames, self.__fn__.__defaults__))
 53            )
 54        if self.__fn__.__kwdefaults__ is not None:
 55            # Update the output dictionary with the keyword default values
 56            self.__fn_defaults__.update(self.__fn__.__kwdefaults__)
 57
 58    def __get_checkable_types__(self):
 59        """
 60        Creates two class attributes:
 61
 62        - `self.__checkable_types__`:
 63            - What: A dictionary of all annotations as checkable types
 64            - Type: dict
 65
 66        - `self.__return_type__`:
 67            - What: The return type of the function or method
 68            - Type: dict | None
 69        """
 70        if not hasattr(self, "__checkable_types__"):
 71            self.__checkable_types__ = {
 72                key: self.__get_checkable_type__(value)
 73                for key, value in get_type_hints(self.__fn__).items()
 74            }
 75            self.__return_type__ = self.__checkable_types__.pop("return", None)
 76
 77    def __get_checkable_type__(self, annotation):
 78        """
 79        Parses a type annotation and returns a nested dict structure
 80        representing the checkable type(s) for validation.
 81        """
 82
 83        if annotation is None:
 84            return {type(None): None}
 85
 86        # Handle `int | str` syntax (Python 3.10+) and Unions
 87        if (
 88            isinstance(annotation, UnionType)
 89            or getattr(annotation, "__origin__", None) == Union
 90        ):
 91            combined_types = {}
 92            for sub_type in annotation.__args__:
 93                combined_types = DeepMerge(
 94                    combined_types, self.__get_checkable_type__(sub_type)
 95                )
 96            return combined_types
 97
 98        # Handle typing.Literal
 99        if getattr(annotation, "__origin__", None) == Literal:
100            return {"__extra__": {"__literal__": list(annotation.__args__)}}
101
102        # Handle generic collections
103        origin = getattr(annotation, "__origin__", None)
104        args = getattr(annotation, "__args__", ())
105
106        if origin == list:
107            if len(args) != 1:
108                self.__exception__(
109                    f"List must have a single type argument, got: {args}",
110                    raise_exception=True,
111                )
112            return {list: self.__get_checkable_type__(args[0])}
113
114        if origin == dict:
115            if len(args) != 2:
116                self.__exception__(
117                    f"Dict must have two type arguments, got: {args}",
118                    raise_exception=True,
119                )
120            key_type = self.__get_checkable_type__(args[0])
121            value_type = self.__get_checkable_type__(args[1])
122            return {dict: (key_type, value_type)}
123
124        if origin == tuple:
125            if len(args) > 2 or len(args) == 1:
126                if Ellipsis in args:
127                    self.__exception__(
128                        "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.",
129                        raise_exception=True,
130                    )
131            if len(args) == 2:
132                if args[0] is Ellipsis:
133                    self.__exception__(
134                        "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.",
135                        raise_exception=True,
136                    )
137                if args[1] is Ellipsis:
138                    return {tuple: (self.__get_checkable_type__(args[0]), True)}
139            return {
140                tuple: (
141                    tuple(self.__get_checkable_type__(arg) for arg in args),
142                    False,
143                )
144            }
145
146        if origin == set:
147            if len(args) != 1:
148                self.__exception__(
149                    f"Set must have a single type argument, got: {args}",
150                    raise_exception=True,
151                )
152            return {set: self.__get_checkable_type__(args[0])}
153
154        # Handle Sized types
155        if annotation == Sized:
156            return {
157                list: None,
158                tuple: None,
159                dict: None,
160                set: None,
161                str: None,
162                bytes: None,
163                bytearray: None,
164                memoryview: None,
165                range: None,
166            }
167
168        # Handle Callable types
169        if annotation == Callable:
170            return {
171                staticmethod: None,
172                classmethod: None,
173                FunctionType: None,
174                BuiltinFunctionType: None,
175                MethodType: None,
176                BuiltinMethodType: None,
177                GeneratorType: None,
178            }
179
180        if annotation == Any:
181            return {
182                object: None,
183            }
184
185        # Handle Constraints
186        if isinstance(annotation, GenericConstraint):
187            return {"__extra__": {"__constraints__": [annotation]}}
188
189        # Handle standard types
190        if isinstance(annotation, type):
191            return {annotation: None}
192
193        # Hanldle typing.Type (for uninitialized classes)
194        if origin is type and len(args) == 1:
195            return {annotation: None}
196
197        self.__exception__(
198            f"Unsupported type hint: {annotation}", raise_exception=True
199        )
200
201    def __exception__(self, message, raise_exception=False):
202        """
203        Usage:
204
205        - Creates a class based exception message
206
207        Requires:
208
209        - `message`:
210            - Type: str
211            - What: The message to warn users with
212
213        Optional:
214
215        - `raise_exception`:
216            - Type: bool
217            - What: Forces an exception to be raised regardless of the `self.__strict__` setting.
218            - Default: False
219        """
220        if self.__strict__ or raise_exception:
221            raise TypeError(
222                f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}"
223            )
224        else:
225            print(
226                f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}"
227            )
228
229    def __get__(self, obj, objtype):
230        """
231        Overwrite standard __get__ method to return __call__ instead for wrapped class methods.
232
233        Also stores the calling (__get__) `obj` to be passed as an initial argument for `__call__` such that methods can pass `self` correctly.
234        """
235
236        @wraps(self.__fn__)
237        def __get_fn__(*args, **kwargs):
238            return self.__call__(*args, **kwargs)
239
240        self.__outer_self__ = obj
241        return __get_fn__
242
243    def __check_method_function__(self):
244        """
245        Validate that `self.__fn__` is a method or function
246        """
247        if not isinstance(self.__fn__, (MethodType, FunctionType)):
248            raise Exception(
249                f"A non function/method was passed to Enforcer. See the stack trace above for more information."
250            )
251
252    def __call__(self, *args, **kwargs):
253        """
254        This method is used to validate the passed inputs and return the output of the wrapped function or method.
255        """
256        # Special code to pass self as an initial argument
257        # for validation purposes in methods
258        # See: self.__get__
259        if self.__outer_self__ is not None:
260            args = (self.__outer_self__, *args)
261        # Get a dictionary of all annotations as checkable types
262        # Note: This is only done once at first call to avoid redundant calculations
263        self.__get_checkable_types__()
264        # Create a compreshensive dictionary of assigned variables (order matters)
265        assigned_vars = {
266            **self.__fn_defaults__,
267            **dict(zip(self.__fn__.__code__.co_varnames[: len(args)], args)),
268            **kwargs,
269        }
270        # Validate all listed annotations vs the assigned_vars dictionary
271        for key, value in self.__checkable_types__.items():
272            self.__check_type__(assigned_vars.get(key), value, key)
273        # Execute the function callable
274        return_value = self.__fn__(*args, **kwargs)
275        # If a return type was passed, validate the returned object
276        if self.__return_type__ is not None:
277            self.__check_type__(return_value, self.__return_type__, "return")
278        return return_value
279
280    def __quick_check__(self, subtype, obj):
281        if all([v == None for v in subtype.values()]):
282            # If the subtype does not contain iterables with typing, we can validate the items directly.
283            types = set(subtype.keys())
284            values = set([type(v) for v in obj])
285            if values.issubset(types):
286                # We can return True to bypass the full validation
287                return True
288            # Otherwise, validation did not pass and a full validation is required to raise an indexed/keyed type mismatch error
289        return False
290
291    def __check_type__(self, obj, expected, key):
292        """
293        Raises an exception the type of a passed `obj` (parameter) is not in the list of supplied `acceptable_types` for the argument.
294        """
295        # Special case for None
296        if obj is None and type(None) in expected:
297            return
298        extra = expected.get("__extra__", {})
299        expected = {k: v for k, v in expected.items() if k != "__extra__"}
300
301        if isinstance(obj, type):
302            # An uninitialized class is passed, we need to check if the type is in the expected types using Type[obj]
303            obj_type = Type[obj]
304            is_present = obj_type in expected
305        else:
306            obj_type = type(obj)
307            is_present = isinstance(obj, tuple(expected.keys()))
308
309        if not is_present:
310            # Allow for literals to be used to bypass type checks if present
311            literal = extra.get("__literal__", ())
312            if literal:
313                if obj not in literal:
314                    self.__exception__(
315                        f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead."
316                    )
317            # Raise an exception if the type is not in the expected types
318            else:
319                self.__exception__(
320                    f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead."
321                )
322        # If the object_type is in the expected types, we can proceed with validation
323        elif obj_type in iterable_types:
324            subtype = expected.get(obj_type, None)
325            if subtype is None:
326                pass
327            # Recursive validation
328            elif obj_type == list:
329                # If the subtype does not contain iterables with typing, we can validate the items directly.
330                if not self.__quick_check__(subtype, obj):
331                    for idx, item in enumerate(obj):
332                        self.__check_type__(item, subtype, f"{key}[{idx}]")
333            elif obj_type == dict:
334                key_type, val_type = subtype
335                if not self.__quick_check__(key_type, obj.keys()):
336                    for key in obj.keys():
337                        self.__check_type__(
338                            key, key_type, f"{key}.key[{repr(key)}]"
339                        )
340                if not self.__quick_check__(val_type, obj.values()):
341                    for key, value in obj.items():
342                        self.__check_type__(
343                            value, val_type, f"{key}[{repr(key)}]"
344                        )
345            elif obj_type == tuple:
346                expected_args, is_ellipsis = subtype
347                if is_ellipsis:
348                    if not self.__quick_check__(expected_args, obj):
349                        for idx, item in enumerate(obj):
350                            self.__check_type__(
351                                item, expected_args, f"{key}[{idx}]"
352                            )
353                else:
354                    if len(obj) != len(expected_args):
355                        self.__exception__(
356                            f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}"
357                        )
358                    for idx, (item, ex) in enumerate(zip(obj, expected_args)):
359                        self.__check_type__(item, ex, f"{key}[{idx}]")
360            elif obj_type == set:
361                if not self.__quick_check__(subtype, obj):
362                    for item in obj:
363                        self.__check_type__(
364                            item, subtype, f"{key}[{repr(item)}]"
365                        )
366
367        # Validate constraints if any are present
368        constraints = extra.get("__constraints__", [])
369        for constraint in constraints:
370            constraint_validation_output = constraint.__validate__(key, obj)
371            if constraint_validation_output is not True:
372                self.__exception__(
373                    f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}"
374                )
375
376    def __repr__(self):
377        return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"
FunctionMethodEnforcer(__fn__, __strict__=False)
21    def __init__(self, __fn__, __strict__=False):
22        """
23        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.
24
25        Requires:
26
27            - `__fn__`:
28                - What: The function to enforce
29                - Type: function | method | class
30        """
31        update_wrapper(self, __fn__)
32        self.__fn__ = __fn__
33        self.__strict__ = __strict__
34        self.__outer_self__ = None
35        # Validate that the passed function or method is a method or function
36        self.__check_method_function__()
37        # Get input defaults for the function or method
38        self.__get_defaults__()

Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function __fn__.

Requires:

- `__fn__`:
    - What: The function to enforce
    - Type: function | method | class
@Partial
def Enforcer(clsFnMethod, enabled=True, strict=True):
@Partial
def Enforcer(clsFnMethod, enabled=True, strict=True):
    """
    A wrapper to enforce types within a function or method given argument annotations.

    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible.

    If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:

    - Methods with `__call__`
    - Methods wrapped with `staticmethod` (if python >= 3.10)
    - Methods wrapped with `classmethod` (if python >= 3.10)

    Objects that reject setting the `__type_enforced_enabled__` attribute are returned unchanged.

    Requires:

    - `clsFnMethod`:
        - What: The class, function or method that should have input types enforced
        - Type: function | method | class

    Optional:

    - `enabled`:
        - What: A boolean to enable or disable the enforcer
        - Type: bool
        - Default: True
        - Note: The flag is stored on `clsFnMethod` the first time it is seen; a later call does not overwrite an already stored flag.
    - `strict`:
        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
        - Type: bool
        - Default: True
        - Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.

    Returns:

    - A `FunctionMethodEnforcer` (wrapped in `staticmethod`/`classmethod` when the input was one) for annotated functions and methods.
    - The same class, with its callables wrapped in place, when a class (or other object with a `__dict__`) is passed.
    - `clsFnMethod` unchanged when enforcement is disabled, no type hints are present, or the flag attribute cannot be set.

    Raises:

    - `Exception` when `clsFnMethod` is not a function or method and has no `__dict__` to walk.


    Example Use:
    ```
    >>> import type_enforced
    >>> @type_enforced.Enforcer
    ... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
    ...     pass
    ...
    >>> my_fn(a=1, b=2, c=3)
    >>> my_fn(a=1, b='2', c=3)
    >>> my_fn(a='a', b=2, c=3)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 85, in __call__
        self.__check_type__(assigned_vars.get(key), value, key)
      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 107, in __check_type__
        self.__exception__(
      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 34, in __exception__
        raise Exception(f"({self.__fn__.__qualname__}): {message}")
    Exception: (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.
    ```
    """
    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
        # Special try except clause to handle cases when the object is immutable.
        # Catch `Exception` (not a bare except) so this stays best-effort without
        # swallowing KeyboardInterrupt / SystemExit.
        try:
            clsFnMethod.__type_enforced_enabled__ = enabled
        except Exception:
            return clsFnMethod
    if not clsFnMethod.__type_enforced_enabled__:
        return clsFnMethod
    if isinstance(
        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
    ):
        # Only apply the enforcer if type_hints are present
        if get_type_hints(clsFnMethod) == {}:
            return clsFnMethod
        if isinstance(clsFnMethod, staticmethod):
            return staticmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__, __strict__=strict
                )
            )
        if isinstance(clsFnMethod, classmethod):
            return classmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__, __strict__=strict
                )
            )
        return FunctionMethodEnforcer(__fn__=clsFnMethod, __strict__=strict)
    if hasattr(clsFnMethod, "__dict__"):
        # Iterate over a snapshot so rebinding attributes below never mutates the
        # mapping that is being iterated.
        for key, value in list(clsFnMethod.__dict__.items()):
            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
            if key == "__annotate__" or isinstance(
                value, FunctionMethodEnforcer
            ):
                continue
            if hasattr(value, "__call__") or isinstance(
                value, (classmethod, staticmethod)
            ):
                setattr(
                    clsFnMethod,
                    key,
                    Enforcer(value, enabled=enabled, strict=strict),
                )
        return clsFnMethod
    raise Exception(
        "Enforcer can only be used on classes, methods, or functions."
    )

A wrapper to enforce types within a function or method given argument annotations.

Each wrapped item is converted into a special FunctionMethodEnforcer class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a FunctionMethodEnforcer class as no validation is possible.

If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:

  • Methods with __call__
  • Methods wrapped with staticmethod (if python >= 3.10)
  • Methods wrapped with classmethod (if python >= 3.10)

Requires:

  • clsFnMethod:
    • What: The class, function or method that should have input types enforced
    • Type: function | method | class

Optional:

  • enabled:
    • What: A boolean to enable or disable the enforcer
    • Type: bool
    • Default: True
  • strict:
    • What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
    • Type: bool
    • Default: True
    • Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.

Example Use:

>>> import type_enforced
>>> @type_enforced.Enforcer
... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
...     pass
...
>>> my_fn(a=1, b=2, c=3)
>>> my_fn(a=1, b='2', c=3)
>>> my_fn(a='a', b=2, c=3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 85, in __call__
    self.__check_type__(assigned_vars.get(key), value, key)
  File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 107, in __check_type__
    self.__exception__(
  File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 34, in __exception__
    raise Exception(f"({self.__fn__.__qualname__}): {message}")
Exception: (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.