type_enforced.enforcer

  1from types import (
  2    FunctionType,
  3    MethodType,
  4    GeneratorType,
  5    BuiltinFunctionType,
  6    BuiltinMethodType,
  7    UnionType,
  8)
  9from typing import Type, Union, Sized, Literal, Callable, get_type_hints, Any
 10from functools import update_wrapper
 11from type_enforced.utils import (
 12    Partial,
 13    GenericConstraint,
 14    iterable_types,
 15    merge_type_dicts,
 16)
 17import sys, traceback, random
 18from pathlib import Path
 19
 20_NoneType = type(None)
 21_package_path = Path(__file__).parent.resolve()
 22
 23
 24class FunctionMethodEnforcer:
 25    __slots__ = (
 26        "__fn__",
 27        "__strict__",
 28        "__clean_traceback__",
 29        "__iterable_sample_pct__",
 30        "__outer_self__",
 31        "__fn_defaults__",
 32        "__fn_varnames__",
 33        "__types_parsed__",
 34        "__checkable_types__",
 35        "__return_type__",
 36        "__simple_types__",
 37        "__complex_types__",
 38        "__simple_return_type__",
 39        "__param_indices__",
 40        "__flat_subtypes__",
 41        "__wrapped__",
 42        "__name__",
 43        "__qualname__",
 44        "__doc__",
 45        "__dict__",
 46    )
 47
 48    def __init__(
 49        self,
 50        __fn__,
 51        __strict__=False,
 52        __clean_traceback__=True,
 53        __iterable_sample_pct__=100,
 54    ):
 55        """
 56        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.
 57
 58        Requires:
 59
 60            - `__fn__`:
 61                - What: The function to enforce
 62                - Type: function | method | class
 63
 64        Optional:
 65
 66            - `__strict__`:
 67                - What: A boolean to enable or disable exceptions. If True, exceptions will be raised
 68                    when type checking fails. If False, exceptions will not be raised but instead a warning
 69                    will be printed to the console.
 70                - Type: bool
 71                - Default: False
 72            - `__clean_traceback__`:
 73                - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
 74                - Type: bool
 75                - Default: True
 76            - `__iterable_sample_pct__`:
 77                - What: The percentage of items to sample when validating iterables. If 100, all items
 78                    are validated. If less than 100, the first and last items are always validated
 79                    plus a random sample of the remaining items up to the specified percentage.
 80                - Type: int | float
 81                - Default: 100
 82        """
 83        update_wrapper(self, __fn__)
 84        self.__fn__ = __fn__
 85        self.__strict__ = __strict__
 86        self.__clean_traceback__ = __clean_traceback__
 87        self.__iterable_sample_pct__ = __iterable_sample_pct__
 88        self.__outer_self__ = None
 89        self.__types_parsed__ = False
 90        self.__flat_subtypes__ = {}
 91        # Validate that the passed function or method is a method or function
 92        self.__check_method_function__()
 93        # Get input defaults for the function or method
 94        self.__get_defaults__()
 95
 96    def __get_defaults__(self):
 97        """
 98        Get the default values of the passed function or method and store them in `self.__fn_defaults__`.
 99        Also caches the function's variable names for use in `__call__`.
100        """
101        self.__fn_varnames__ = self.__fn__.__code__.co_varnames
102        self.__fn_defaults__ = {}
103        if self.__fn__.__defaults__ is not None:
104            # Get the names of all provided default values for args
105            default_varnames = list(self.__fn_varnames__)[
106                : self.__fn__.__code__.co_argcount
107            ][-len(self.__fn__.__defaults__) :]
108            # Update the output dictionary with the default values
109            self.__fn_defaults__.update(
110                dict(zip(default_varnames, self.__fn__.__defaults__))
111            )
112        if self.__fn__.__kwdefaults__ is not None:
113            # Update the output dictionary with the keyword default values
114            self.__fn_defaults__.update(self.__fn__.__kwdefaults__)
115
116    def __get_sample_indices__(self, length):
117        """
118        Get a sorted list of indices to sample for iterable validation.
119
120        If iterable_sample_pct is 0, only the first item (index 0) is checked.
121        Otherwise, always includes the first (0) and last (length-1) indices.
122        If length > 3, samples additional middle indices up to the
123        iterable_sample_pct percentage.
124        Only called when self.__iterable_sample_pct__ < 100.
125        """
126        if length == 0:
127            return []
128        if self.__iterable_sample_pct__ == 0:
129            return [0]
130        if length <= 3:
131            return range(length)
132        n = max(3, int(length * self.__iterable_sample_pct__ / 100))
133        if n >= length:
134            return range(length)
135        middle_sample = random.sample(range(1, length - 1), n - 2)
136        return sorted([0] + middle_sample + [length - 1])
137
138    def __get_sample_keys__(self, keys):
139        """
140        Get a sampled list of dict keys for iterable validation.
141
142        If iterable_sample_pct is 0, only the first key is returned.
143        Otherwise, always includes the first key plus a random sample of
144        the remaining keys up to the iterable_sample_pct percentage.
145        Only called when self.__iterable_sample_pct__ < 100.
146        """
147        if len(keys) == 0:
148            return []
149        if self.__iterable_sample_pct__ == 0 or len(keys) == 1:
150            return [keys[0]]
151        n = max(1, int(len(keys) * self.__iterable_sample_pct__ / 100))
152        if n >= len(keys):
153            return keys
154        return [keys[0]] + random.sample(keys[1:], n - 1)
155
156    def __get_checkable_types__(self):
157        """
158        Creates two class attributes:
159
160        - `self.__checkable_types__`:
161            - What: A dictionary of all annotations as checkable types
162            - Type: dict
163
164        - `self.__return_type__`:
165            - What: The return type of the function or method
166            - Type: dict | None
167        """
168        if not self.__types_parsed__:
169            self.__checkable_types__ = {
170                key: self.__get_checkable_type__(value)
171                for key, value in get_type_hints(self.__fn__).items()
172            }
173            self.__return_type__ = self.__checkable_types__.pop("return", None)
174            # Classify params: simple types can use a single
175            # isinstance call, skipping __check_type__ entirely.
176            self.__simple_types__ = {}
177            self.__complex_types__ = {}
178            for key, expected in self.__checkable_types__.items():
179                if (
180                    "__extra__" not in expected
181                    and all(v is None for v in expected.values())
182                    and all(isinstance(k, type) for k in expected.keys())
183                ):
184                    self.__simple_types__[key] = tuple(expected.keys())
185                else:
186                    self.__complex_types__[key] = expected
187            # Same classification for return type
188            if self.__return_type__ is not None and (
189                "__extra__" not in self.__return_type__
190                and all(v is None for v in self.__return_type__.values())
191                and all(
192                    isinstance(k, type) for k in self.__return_type__.keys()
193                )
194            ):
195                self.__simple_return_type__ = tuple(self.__return_type__.keys())
196            else:
197                self.__simple_return_type__ = None
198            # Pre-compute param index in co_varnames for
199            # direct arg lookup (skips assigned_vars dict).
200            self.__param_indices__ = {
201                name: i
202                for i, name in enumerate(self.__fn_varnames__)
203                if name in self.__checkable_types__
204            }
205            self.__types_parsed__ = True
206
207    def __get_checkable_type__(self, annotation):
208        """
209        Parses a type annotation and returns a nested dict structure
210        representing the checkable type(s) for validation.
211        """
212
213        if annotation is None:
214            return {_NoneType: None}
215
216        # Handle `int | str` syntax (Python 3.10+) and Unions
217        if (
218            isinstance(annotation, UnionType)
219            or getattr(annotation, "__origin__", None) == Union
220        ):
221            combined_types = {}
222            for sub_type in annotation.__args__:
223                merge_type_dicts(
224                    combined_types,
225                    self.__get_checkable_type__(sub_type),
226                )
227            return combined_types
228
229        # Handle typing.Literal
230        if getattr(annotation, "__origin__", None) == Literal:
231            return {"__extra__": {"__literal__": list(annotation.__args__)}}
232
233        # Handle generic collections
234        origin = getattr(annotation, "__origin__", None)
235        args = getattr(annotation, "__args__", ())
236
237        if origin == list:
238            if len(args) != 1:
239                self.__exception__(
240                    f"List must have a single type argument, got: {args}",
241                    raise_exception=True,
242                )
243            return {list: self.__get_checkable_type__(args[0])}
244
245        if origin == dict:
246            if len(args) != 2:
247                self.__exception__(
248                    f"Dict must have two type arguments, got: {args}",
249                    raise_exception=True,
250                )
251            key_type = self.__get_checkable_type__(args[0])
252            value_type = self.__get_checkable_type__(args[1])
253            return {dict: (key_type, value_type)}
254
255        if origin == tuple:
256            if len(args) > 2 or len(args) == 1:
257                if Ellipsis in args:
258                    self.__exception__(
259                        "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.",
260                        raise_exception=True,
261                    )
262            if len(args) == 2:
263                if args[0] is Ellipsis:
264                    self.__exception__(
265                        "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.",
266                        raise_exception=True,
267                    )
268                if args[1] is Ellipsis:
269                    return {tuple: (self.__get_checkable_type__(args[0]), True)}
270            return {
271                tuple: (
272                    tuple(self.__get_checkable_type__(arg) for arg in args),
273                    False,
274                )
275            }
276
277        if origin == set:
278            if len(args) != 1:
279                self.__exception__(
280                    f"Set must have a single type argument, got: {args}",
281                    raise_exception=True,
282                )
283            return {set: self.__get_checkable_type__(args[0])}
284
285        # Handle Sized types
286        if annotation == Sized:
287            return {
288                list: None,
289                tuple: None,
290                dict: None,
291                set: None,
292                str: None,
293                bytes: None,
294                bytearray: None,
295                memoryview: None,
296                range: None,
297            }
298
299        # Handle Callable types
300        if annotation == Callable:
301            return {
302                staticmethod: None,
303                classmethod: None,
304                FunctionType: None,
305                BuiltinFunctionType: None,
306                MethodType: None,
307                BuiltinMethodType: None,
308                GeneratorType: None,
309            }
310
311        if annotation == Any:
312            return {
313                object: None,
314            }
315
316        # Handle Constraints
317        if isinstance(annotation, GenericConstraint):
318            return {"__extra__": {"__constraints__": [annotation]}}
319
320        # Handle standard types
321        if isinstance(annotation, type):
322            return {annotation: None}
323
324        # Hanldle typing.Type (for uninitialized classes)
325        if origin is type and len(args) == 1:
326            return {annotation: None}
327
328        self.__exception__(
329            f"Unsupported type hint: {annotation}", raise_exception=True
330        )
331
332    def __exception__(self, message, raise_exception=False):
333        """
334        Usage:
335
336        - Creates a class based exception message
337
338        Requires:
339
340        - `message`:
341            - Type: str
342            - What: The message to warn users with
343
344        Optional:
345
346        - `raise_exception`:
347            - Type: bool
348            - What: Forces an exception to be raised regardless of the `self.__strict__` setting.
349            - Default: False
350        """
351        if self.__strict__ or raise_exception:
352            msg = f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}"
353            if self.__clean_traceback__:
354                package_path = _package_path
355                frame = sys._getframe()
356                relevant_tb_count = 0
357                while frame is not None:
358                    frame_file = Path(frame.f_code.co_filename).resolve()
359                    try:
360                        frame_file.relative_to(package_path)
361                    except ValueError:
362                        relevant_tb_count += 1
363                    frame = frame.f_back
364                original_excepthook = sys.excepthook
365
366                def excepthook(type, value, tb):
367                    traceback.print_exception(
368                        type, value, tb, limit=relevant_tb_count
369                    )
370                    sys.excepthook = original_excepthook
371
372                sys.excepthook = excepthook
373            raise TypeError(msg)
374        else:
375            print(
376                f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}"
377            )
378
379    def __get__(self, obj, objtype):
380        """
381        Overwrite standard __get__ method to return __call__ instead for wrapped class methods.
382
383        Also stores the calling (__get__) `obj` to be passed as an initial argument for `__call__` such that methods can pass `self` correctly.
384        """
385        self.__outer_self__ = obj
386
387        def __get_fn__(*args, **kwargs):
388            return self.__call__(*args, **kwargs)
389
390        __get_fn__.__name__ = self.__fn__.__name__
391        __get_fn__.__qualname__ = self.__fn__.__qualname__
392        __get_fn__.__doc__ = self.__fn__.__doc__
393        return __get_fn__
394
395    def __check_method_function__(self):
396        """
397        Validate that `self.__fn__` is a method or function
398        """
399        if not isinstance(self.__fn__, (MethodType, FunctionType)):
400            raise Exception(
401                f"A non function/method was passed to Enforcer. See the stack trace above for more information."
402            )
403
404    def __call__(self, *args, **kwargs):
405        """
406        This method is used to validate the passed inputs and return the output of the wrapped function or method.
407        """
408        # Special code to pass self as an initial argument
409        # for validation purposes in methods
410        # See: self.__get__
411        if self.__outer_self__ is not None:
412            args = (self.__outer_self__, *args)
413        # Get a dictionary of all annotations as checkable types
414        # Note: This is only done once at first call to avoid redundant calculations
415        self.__get_checkable_types__()
416        # Fast path: simple types use direct index lookup
417        for key, types_tuple in self.__simple_types__.items():
418            idx = self.__param_indices__[key]
419            if idx < len(args):
420                obj = args[idx]
421            elif key in kwargs:
422                obj = kwargs[key]
423            else:
424                obj = self.__fn_defaults__.get(key)
425            if not isinstance(obj, types_tuple):
426                # Fall back to full check for error reporting
427                self.__check_type__(obj, self.__checkable_types__[key], key)
428        # Full validation for complex types (nested, extras, Type[X])
429        if self.__complex_types__:
430            assigned_vars = {
431                **self.__fn_defaults__,
432                **dict(zip(self.__fn_varnames__[: len(args)], args)),
433                **kwargs,
434            }
435            for key, value in self.__complex_types__.items():
436                self.__check_type__(assigned_vars.get(key), value, key)
437        # Execute the function callable
438        return_value = self.__fn__(*args, **kwargs)
439        # If a return type was passed, validate the returned object
440        if self.__return_type__ is not None:
441            if self.__simple_return_type__ is not None:
442                if not isinstance(return_value, self.__simple_return_type__):
443                    self.__check_type__(
444                        return_value, self.__return_type__, "return"
445                    )
446            else:
447                self.__check_type__(
448                    return_value, self.__return_type__, "return"
449                )
450        return return_value
451
452    def __quick_check__(self, subtype, obj):
453        subtype_id = id(subtype)
454        if subtype_id not in self.__flat_subtypes__:
455            # First call for this subtype: compute and cache
456            if all(v is None for v in subtype.values()):
457                self.__flat_subtypes__[subtype_id] = frozenset(subtype.keys())
458            else:
459                self.__flat_subtypes__[subtype_id] = None
460        flat_keys = self.__flat_subtypes__[subtype_id]
461        if flat_keys is not None:
462            values = {type(v) for v in obj}
463            if values.issubset(flat_keys):
464                return True
465        return False
466
467    def __check_type__(self, obj, expected, key):
468        """
469        Raises an exception the type of a passed `obj` (parameter) is not in the list of supplied `acceptable_types` for the argument.
470        """
471        # Special case for None
472        if obj is None and _NoneType in expected:
473            return
474        if "__extra__" in expected:
475            extra = expected["__extra__"]
476            expected = {k: v for k, v in expected.items() if k != "__extra__"}
477        else:
478            extra = None
479
480        if isinstance(obj, type):
481            # An uninitialized class is passed, we need to check if the type is in the expected types using Type[obj]
482            obj_type = Type[obj]
483            is_present = obj_type in expected
484        else:
485            obj_type = type(obj)
486            is_present = obj_type in expected or isinstance(
487                obj, tuple(expected.keys())
488            )
489
490        if not is_present:
491            # Allow for literals to be used to bypass type checks if present
492            literal = extra.get("__literal__", ()) if extra is not None else ()
493            if literal:
494                if obj not in literal:
495                    self.__exception__(
496                        f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead."
497                    )
498            # Raise an exception if the type is not in the expected types
499            else:
500                self.__exception__(
501                    f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead."
502                )
503        # If the object_type is in the expected types, we can proceed with validation
504        elif obj_type in iterable_types:
505            subtype = expected.get(obj_type, None)
506            if subtype is None:
507                pass
508            # Recursive validation
509            elif obj_type == list:
510                if self.__iterable_sample_pct__ < 100:
511                    for idx in self.__get_sample_indices__(len(obj)):
512                        self.__check_type__(obj[idx], subtype, f"{key}[{idx}]")
513                # If the subtype does not contain iterables with typing, we can validate the items directly.
514                elif not self.__quick_check__(subtype, obj):
515                    for idx, item in enumerate(obj):
516                        self.__check_type__(item, subtype, f"{key}[{idx}]")
517            elif obj_type == dict:
518                key_type, val_type = subtype
519                if self.__iterable_sample_pct__ < 100:
520                    sampled_keys = self.__get_sample_keys__(list(obj.keys()))
521                    if not self.__quick_check__(key_type, sampled_keys):
522                        for dk in sampled_keys:
523                            self.__check_type__(
524                                dk, key_type, f"{key}.key[{repr(dk)}]"
525                            )
526                    if not self.__quick_check__(
527                        val_type, [obj[dk] for dk in sampled_keys]
528                    ):
529                        for dk in sampled_keys:
530                            self.__check_type__(
531                                obj[dk], val_type, f"{key}[{repr(dk)}]"
532                            )
533                else:
534                    if not self.__quick_check__(key_type, obj.keys()):
535                        for key in obj.keys():
536                            self.__check_type__(
537                                key, key_type, f"{key}.key[{repr(key)}]"
538                            )
539                    if not self.__quick_check__(val_type, obj.values()):
540                        for key, value in obj.items():
541                            self.__check_type__(
542                                value, val_type, f"{key}[{repr(key)}]"
543                            )
544            elif obj_type == tuple:
545                expected_args, is_ellipsis = subtype
546                if is_ellipsis:
547                    if self.__iterable_sample_pct__ < 100:
548                        for idx in self.__get_sample_indices__(len(obj)):
549                            self.__check_type__(
550                                obj[idx], expected_args, f"{key}[{idx}]"
551                            )
552                    elif not self.__quick_check__(expected_args, obj):
553                        for idx, item in enumerate(obj):
554                            self.__check_type__(
555                                item, expected_args, f"{key}[{idx}]"
556                            )
557                else:
558                    if len(obj) != len(expected_args):
559                        self.__exception__(
560                            f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}"
561                        )
562                    for idx, (item, ex) in enumerate(zip(obj, expected_args)):
563                        self.__check_type__(item, ex, f"{key}[{idx}]")
564            elif obj_type == set:
565                if self.__iterable_sample_pct__ < 100:
566                    obj_list = list(obj)
567                    for idx in self.__get_sample_indices__(len(obj_list)):
568                        item = obj_list[idx]
569                        self.__check_type__(
570                            item, subtype, f"{key}[{repr(item)}]"
571                        )
572                elif not self.__quick_check__(subtype, obj):
573                    for item in obj:
574                        self.__check_type__(
575                            item, subtype, f"{key}[{repr(item)}]"
576                        )
577
578        # Validate constraints if any are present
579        if extra is not None:
580            constraints = extra.get("__constraints__", ())
581            for constraint in constraints:
582                constraint_validation_output = constraint.__validate__(key, obj)
583                if constraint_validation_output is not True:
584                    self.__exception__(
585                        f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}"
586                    )
587
588    def __repr__(self):
589        return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"
590
591
@Partial
def Enforcer(
    clsFnMethod,
    enabled=True,
    strict=True,
    clean_traceback=True,
    iterable_sample_pct=100,
):
    """
    A wrapper to enforce types within a function or method given argument annotations.

    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible.

    If wrapping a class, all items in the class `__dict__` that meet any of the following criteria are wrapped individually (items that are already `FunctionMethodEnforcer` objects, and `__annotate__`, are skipped):

    - Callables (anything with `__call__`, including nested classes)
    - Methods wrapped with `staticmethod`
    - Methods wrapped with `classmethod`

    Objects that do not accept the `__type_enforced_enabled__` marker attribute are returned unchanged.

    Requires:

    - `clsFnMethod`:
        - What: The class, function or method that should have input types enforced
        - Type: function | method | class

    Optional:

    - `enabled`:
        - What: A boolean to enable or disable the enforcer
        - Type: bool
        - Default: True
    - `strict`:
        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
        - Type: bool
        - Default: True
        - Note: Invalid or unsupported type hints always raise an exception, even when `strict` is False.
    - `clean_traceback`:
        - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
        - If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
        - Type: bool
        - Default: True
    - `iterable_sample_pct`:
        - What: The percentage (0-100) of items to validate when checking typed iterables (list,
            dict, set, variable-length tuple). At 100 (default) every item is checked. At 0 only
            the first item is checked. Otherwise, for lists, sets and variable-length tuples the
            first and last items are always checked and, if the collection has more than 3 items,
            additional items are randomly sampled so that the total checked is at least 3; for
            dicts the first key plus a random sample of the remaining keys is checked.
        - Type: int | float
        - Default: 100

    Raises:

    - `Exception` if `clsFnMethod` is not a class, method, or function.

    Example Use:
    ```
    >>> import type_enforced
    >>> @type_enforced.Enforcer
    ... def my_fn(a: int, b: int | str = 2, c: int = 3) -> None:
    ...     pass
    ...
    >>> my_fn(a=1, b=2, c=3)
    >>> my_fn(a=1, b='2', c=3)
    >>> my_fn(a='a', b=2, c=3)
    Traceback (most recent call last):
      ...
    TypeError: TypeEnforced Exception (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.
    ```
    """
    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
        # Record the flag on the target. Objects that refuse new attributes
        # (e.g. immutable builtins) are returned untouched. `Exception` rather
        # than a bare except so KeyboardInterrupt/SystemExit still propagate.
        try:
            clsFnMethod.__type_enforced_enabled__ = enabled
        except Exception:
            return clsFnMethod
    if not clsFnMethod.__type_enforced_enabled__:
        return clsFnMethod

    def wrap(fn):
        # Build the per-function enforcer with this decorator's settings
        return FunctionMethodEnforcer(
            __fn__=fn,
            __strict__=strict,
            __clean_traceback__=clean_traceback,
            __iterable_sample_pct__=iterable_sample_pct,
        )

    if isinstance(
        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
    ):
        # Only apply the enforcer if type hints are present. Resolving hints
        # can fail here (e.g. forward references that are not defined yet); in
        # that case wrap anyway, since the enforcer only resolves them on its
        # first call.
        try:
            if get_type_hints(clsFnMethod) == {}:
                return clsFnMethod
        except Exception:
            pass
        if isinstance(clsFnMethod, staticmethod):
            return staticmethod(wrap(clsFnMethod.__func__))
        if isinstance(clsFnMethod, classmethod):
            return classmethod(wrap(clsFnMethod.__func__))
        return wrap(clsFnMethod)
    if hasattr(clsFnMethod, "__dict__"):
        # Iterate over a snapshot: setattr below rebinds entries of the very
        # namespace being traversed.
        for key, value in list(clsFnMethod.__dict__.items()):
            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
            if key == "__annotate__" or isinstance(
                value, FunctionMethodEnforcer
            ):
                continue
            if hasattr(value, "__call__") or isinstance(
                value, (classmethod, staticmethod)
            ):
                setattr(
                    clsFnMethod,
                    key,
                    Enforcer(
                        value,
                        enabled=enabled,
                        strict=strict,
                        clean_traceback=clean_traceback,
                        iterable_sample_pct=iterable_sample_pct,
                    ),
                )
        return clsFnMethod
    raise Exception(
        "Enforcer can only be used on classes, methods, or functions."
    )
class FunctionMethodEnforcer:
 25class FunctionMethodEnforcer:
 26    __slots__ = (
 27        "__fn__",
 28        "__strict__",
 29        "__clean_traceback__",
 30        "__iterable_sample_pct__",
 31        "__outer_self__",
 32        "__fn_defaults__",
 33        "__fn_varnames__",
 34        "__types_parsed__",
 35        "__checkable_types__",
 36        "__return_type__",
 37        "__simple_types__",
 38        "__complex_types__",
 39        "__simple_return_type__",
 40        "__param_indices__",
 41        "__flat_subtypes__",
 42        "__wrapped__",
 43        "__name__",
 44        "__qualname__",
 45        "__doc__",
 46        "__dict__",
 47    )
 48
 49    def __init__(
 50        self,
 51        __fn__,
 52        __strict__=False,
 53        __clean_traceback__=True,
 54        __iterable_sample_pct__=100,
 55    ):
 56        """
 57        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.
 58
 59        Requires:
 60
 61            - `__fn__`:
 62                - What: The function to enforce
 63                - Type: function | method | class
 64
 65        Optional:
 66
 67            - `__strict__`:
 68                - What: A boolean to enable or disable exceptions. If True, exceptions will be raised
 69                    when type checking fails. If False, exceptions will not be raised but instead a warning
 70                    will be printed to the console.
 71                - Type: bool
 72                - Default: False
 73            - `__clean_traceback__`:
 74                - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
 75                - Type: bool
 76                - Default: True
 77            - `__iterable_sample_pct__`:
 78                - What: The percentage of items to sample when validating iterables. If 100, all items
 79                    are validated. If less than 100, the first and last items are always validated
 80                    plus a random sample of the remaining items up to the specified percentage.
 81                - Type: int | float
 82                - Default: 100
 83        """
 84        update_wrapper(self, __fn__)
 85        self.__fn__ = __fn__
 86        self.__strict__ = __strict__
 87        self.__clean_traceback__ = __clean_traceback__
 88        self.__iterable_sample_pct__ = __iterable_sample_pct__
 89        self.__outer_self__ = None
 90        self.__types_parsed__ = False
 91        self.__flat_subtypes__ = {}
 92        # Validate that the passed function or method is a method or function
 93        self.__check_method_function__()
 94        # Get input defaults for the function or method
 95        self.__get_defaults__()
 96
 97    def __get_defaults__(self):
 98        """
 99        Get the default values of the passed function or method and store them in `self.__fn_defaults__`.
100        Also caches the function's variable names for use in `__call__`.
101        """
102        self.__fn_varnames__ = self.__fn__.__code__.co_varnames
103        self.__fn_defaults__ = {}
104        if self.__fn__.__defaults__ is not None:
105            # Get the names of all provided default values for args
106            default_varnames = list(self.__fn_varnames__)[
107                : self.__fn__.__code__.co_argcount
108            ][-len(self.__fn__.__defaults__) :]
109            # Update the output dictionary with the default values
110            self.__fn_defaults__.update(
111                dict(zip(default_varnames, self.__fn__.__defaults__))
112            )
113        if self.__fn__.__kwdefaults__ is not None:
114            # Update the output dictionary with the keyword default values
115            self.__fn_defaults__.update(self.__fn__.__kwdefaults__)
116
117    def __get_sample_indices__(self, length):
118        """
119        Get a sorted list of indices to sample for iterable validation.
120
121        If iterable_sample_pct is 0, only the first item (index 0) is checked.
122        Otherwise, always includes the first (0) and last (length-1) indices.
123        If length > 3, samples additional middle indices up to the
124        iterable_sample_pct percentage.
125        Only called when self.__iterable_sample_pct__ < 100.
126        """
127        if length == 0:
128            return []
129        if self.__iterable_sample_pct__ == 0:
130            return [0]
131        if length <= 3:
132            return range(length)
133        n = max(3, int(length * self.__iterable_sample_pct__ / 100))
134        if n >= length:
135            return range(length)
136        middle_sample = random.sample(range(1, length - 1), n - 2)
137        return sorted([0] + middle_sample + [length - 1])
138
139    def __get_sample_keys__(self, keys):
140        """
141        Get a sampled list of dict keys for iterable validation.
142
143        If iterable_sample_pct is 0, only the first key is returned.
144        Otherwise, always includes the first key plus a random sample of
145        the remaining keys up to the iterable_sample_pct percentage.
146        Only called when self.__iterable_sample_pct__ < 100.
147        """
148        if len(keys) == 0:
149            return []
150        if self.__iterable_sample_pct__ == 0 or len(keys) == 1:
151            return [keys[0]]
152        n = max(1, int(len(keys) * self.__iterable_sample_pct__ / 100))
153        if n >= len(keys):
154            return keys
155        return [keys[0]] + random.sample(keys[1:], n - 1)
156
157    def __get_checkable_types__(self):
158        """
159        Creates two class attributes:
160
161        - `self.__checkable_types__`:
162            - What: A dictionary of all annotations as checkable types
163            - Type: dict
164
165        - `self.__return_type__`:
166            - What: The return type of the function or method
167            - Type: dict | None
168        """
169        if not self.__types_parsed__:
170            self.__checkable_types__ = {
171                key: self.__get_checkable_type__(value)
172                for key, value in get_type_hints(self.__fn__).items()
173            }
174            self.__return_type__ = self.__checkable_types__.pop("return", None)
175            # Classify params: simple types can use a single
176            # isinstance call, skipping __check_type__ entirely.
177            self.__simple_types__ = {}
178            self.__complex_types__ = {}
179            for key, expected in self.__checkable_types__.items():
180                if (
181                    "__extra__" not in expected
182                    and all(v is None for v in expected.values())
183                    and all(isinstance(k, type) for k in expected.keys())
184                ):
185                    self.__simple_types__[key] = tuple(expected.keys())
186                else:
187                    self.__complex_types__[key] = expected
188            # Same classification for return type
189            if self.__return_type__ is not None and (
190                "__extra__" not in self.__return_type__
191                and all(v is None for v in self.__return_type__.values())
192                and all(
193                    isinstance(k, type) for k in self.__return_type__.keys()
194                )
195            ):
196                self.__simple_return_type__ = tuple(self.__return_type__.keys())
197            else:
198                self.__simple_return_type__ = None
199            # Pre-compute param index in co_varnames for
200            # direct arg lookup (skips assigned_vars dict).
201            self.__param_indices__ = {
202                name: i
203                for i, name in enumerate(self.__fn_varnames__)
204                if name in self.__checkable_types__
205            }
206            self.__types_parsed__ = True
207
208    def __get_checkable_type__(self, annotation):
209        """
210        Parses a type annotation and returns a nested dict structure
211        representing the checkable type(s) for validation.
212        """
213
214        if annotation is None:
215            return {_NoneType: None}
216
217        # Handle `int | str` syntax (Python 3.10+) and Unions
218        if (
219            isinstance(annotation, UnionType)
220            or getattr(annotation, "__origin__", None) == Union
221        ):
222            combined_types = {}
223            for sub_type in annotation.__args__:
224                merge_type_dicts(
225                    combined_types,
226                    self.__get_checkable_type__(sub_type),
227                )
228            return combined_types
229
230        # Handle typing.Literal
231        if getattr(annotation, "__origin__", None) == Literal:
232            return {"__extra__": {"__literal__": list(annotation.__args__)}}
233
234        # Handle generic collections
235        origin = getattr(annotation, "__origin__", None)
236        args = getattr(annotation, "__args__", ())
237
238        if origin == list:
239            if len(args) != 1:
240                self.__exception__(
241                    f"List must have a single type argument, got: {args}",
242                    raise_exception=True,
243                )
244            return {list: self.__get_checkable_type__(args[0])}
245
246        if origin == dict:
247            if len(args) != 2:
248                self.__exception__(
249                    f"Dict must have two type arguments, got: {args}",
250                    raise_exception=True,
251                )
252            key_type = self.__get_checkable_type__(args[0])
253            value_type = self.__get_checkable_type__(args[1])
254            return {dict: (key_type, value_type)}
255
256        if origin == tuple:
257            if len(args) > 2 or len(args) == 1:
258                if Ellipsis in args:
259                    self.__exception__(
260                        "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.",
261                        raise_exception=True,
262                    )
263            if len(args) == 2:
264                if args[0] is Ellipsis:
265                    self.__exception__(
266                        "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.",
267                        raise_exception=True,
268                    )
269                if args[1] is Ellipsis:
270                    return {tuple: (self.__get_checkable_type__(args[0]), True)}
271            return {
272                tuple: (
273                    tuple(self.__get_checkable_type__(arg) for arg in args),
274                    False,
275                )
276            }
277
278        if origin == set:
279            if len(args) != 1:
280                self.__exception__(
281                    f"Set must have a single type argument, got: {args}",
282                    raise_exception=True,
283                )
284            return {set: self.__get_checkable_type__(args[0])}
285
286        # Handle Sized types
287        if annotation == Sized:
288            return {
289                list: None,
290                tuple: None,
291                dict: None,
292                set: None,
293                str: None,
294                bytes: None,
295                bytearray: None,
296                memoryview: None,
297                range: None,
298            }
299
300        # Handle Callable types
301        if annotation == Callable:
302            return {
303                staticmethod: None,
304                classmethod: None,
305                FunctionType: None,
306                BuiltinFunctionType: None,
307                MethodType: None,
308                BuiltinMethodType: None,
309                GeneratorType: None,
310            }
311
312        if annotation == Any:
313            return {
314                object: None,
315            }
316
317        # Handle Constraints
318        if isinstance(annotation, GenericConstraint):
319            return {"__extra__": {"__constraints__": [annotation]}}
320
321        # Handle standard types
322        if isinstance(annotation, type):
323            return {annotation: None}
324
325        # Hanldle typing.Type (for uninitialized classes)
326        if origin is type and len(args) == 1:
327            return {annotation: None}
328
329        self.__exception__(
330            f"Unsupported type hint: {annotation}", raise_exception=True
331        )
332
333    def __exception__(self, message, raise_exception=False):
334        """
335        Usage:
336
337        - Creates a class based exception message
338
339        Requires:
340
341        - `message`:
342            - Type: str
343            - What: The message to warn users with
344
345        Optional:
346
347        - `raise_exception`:
348            - Type: bool
349            - What: Forces an exception to be raised regardless of the `self.__strict__` setting.
350            - Default: False
351        """
352        if self.__strict__ or raise_exception:
353            msg = f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}"
354            if self.__clean_traceback__:
355                package_path = _package_path
356                frame = sys._getframe()
357                relevant_tb_count = 0
358                while frame is not None:
359                    frame_file = Path(frame.f_code.co_filename).resolve()
360                    try:
361                        frame_file.relative_to(package_path)
362                    except ValueError:
363                        relevant_tb_count += 1
364                    frame = frame.f_back
365                original_excepthook = sys.excepthook
366
367                def excepthook(type, value, tb):
368                    traceback.print_exception(
369                        type, value, tb, limit=relevant_tb_count
370                    )
371                    sys.excepthook = original_excepthook
372
373                sys.excepthook = excepthook
374            raise TypeError(msg)
375        else:
376            print(
377                f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}"
378            )
379
380    def __get__(self, obj, objtype):
381        """
382        Overwrite standard __get__ method to return __call__ instead for wrapped class methods.
383
384        Also stores the calling (__get__) `obj` to be passed as an initial argument for `__call__` such that methods can pass `self` correctly.
385        """
386        self.__outer_self__ = obj
387
388        def __get_fn__(*args, **kwargs):
389            return self.__call__(*args, **kwargs)
390
391        __get_fn__.__name__ = self.__fn__.__name__
392        __get_fn__.__qualname__ = self.__fn__.__qualname__
393        __get_fn__.__doc__ = self.__fn__.__doc__
394        return __get_fn__
395
396    def __check_method_function__(self):
397        """
398        Validate that `self.__fn__` is a method or function
399        """
400        if not isinstance(self.__fn__, (MethodType, FunctionType)):
401            raise Exception(
402                f"A non function/method was passed to Enforcer. See the stack trace above for more information."
403            )
404
405    def __call__(self, *args, **kwargs):
406        """
407        This method is used to validate the passed inputs and return the output of the wrapped function or method.
408        """
409        # Special code to pass self as an initial argument
410        # for validation purposes in methods
411        # See: self.__get__
412        if self.__outer_self__ is not None:
413            args = (self.__outer_self__, *args)
414        # Get a dictionary of all annotations as checkable types
415        # Note: This is only done once at first call to avoid redundant calculations
416        self.__get_checkable_types__()
417        # Fast path: simple types use direct index lookup
418        for key, types_tuple in self.__simple_types__.items():
419            idx = self.__param_indices__[key]
420            if idx < len(args):
421                obj = args[idx]
422            elif key in kwargs:
423                obj = kwargs[key]
424            else:
425                obj = self.__fn_defaults__.get(key)
426            if not isinstance(obj, types_tuple):
427                # Fall back to full check for error reporting
428                self.__check_type__(obj, self.__checkable_types__[key], key)
429        # Full validation for complex types (nested, extras, Type[X])
430        if self.__complex_types__:
431            assigned_vars = {
432                **self.__fn_defaults__,
433                **dict(zip(self.__fn_varnames__[: len(args)], args)),
434                **kwargs,
435            }
436            for key, value in self.__complex_types__.items():
437                self.__check_type__(assigned_vars.get(key), value, key)
438        # Execute the function callable
439        return_value = self.__fn__(*args, **kwargs)
440        # If a return type was passed, validate the returned object
441        if self.__return_type__ is not None:
442            if self.__simple_return_type__ is not None:
443                if not isinstance(return_value, self.__simple_return_type__):
444                    self.__check_type__(
445                        return_value, self.__return_type__, "return"
446                    )
447            else:
448                self.__check_type__(
449                    return_value, self.__return_type__, "return"
450                )
451        return return_value
452
453    def __quick_check__(self, subtype, obj):
454        subtype_id = id(subtype)
455        if subtype_id not in self.__flat_subtypes__:
456            # First call for this subtype: compute and cache
457            if all(v is None for v in subtype.values()):
458                self.__flat_subtypes__[subtype_id] = frozenset(subtype.keys())
459            else:
460                self.__flat_subtypes__[subtype_id] = None
461        flat_keys = self.__flat_subtypes__[subtype_id]
462        if flat_keys is not None:
463            values = {type(v) for v in obj}
464            if values.issubset(flat_keys):
465                return True
466        return False
467
468    def __check_type__(self, obj, expected, key):
469        """
470        Raises an exception the type of a passed `obj` (parameter) is not in the list of supplied `acceptable_types` for the argument.
471        """
472        # Special case for None
473        if obj is None and _NoneType in expected:
474            return
475        if "__extra__" in expected:
476            extra = expected["__extra__"]
477            expected = {k: v for k, v in expected.items() if k != "__extra__"}
478        else:
479            extra = None
480
481        if isinstance(obj, type):
482            # An uninitialized class is passed, we need to check if the type is in the expected types using Type[obj]
483            obj_type = Type[obj]
484            is_present = obj_type in expected
485        else:
486            obj_type = type(obj)
487            is_present = obj_type in expected or isinstance(
488                obj, tuple(expected.keys())
489            )
490
491        if not is_present:
492            # Allow for literals to be used to bypass type checks if present
493            literal = extra.get("__literal__", ()) if extra is not None else ()
494            if literal:
495                if obj not in literal:
496                    self.__exception__(
497                        f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead."
498                    )
499            # Raise an exception if the type is not in the expected types
500            else:
501                self.__exception__(
502                    f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead."
503                )
504        # If the object_type is in the expected types, we can proceed with validation
505        elif obj_type in iterable_types:
506            subtype = expected.get(obj_type, None)
507            if subtype is None:
508                pass
509            # Recursive validation
510            elif obj_type == list:
511                if self.__iterable_sample_pct__ < 100:
512                    for idx in self.__get_sample_indices__(len(obj)):
513                        self.__check_type__(obj[idx], subtype, f"{key}[{idx}]")
514                # If the subtype does not contain iterables with typing, we can validate the items directly.
515                elif not self.__quick_check__(subtype, obj):
516                    for idx, item in enumerate(obj):
517                        self.__check_type__(item, subtype, f"{key}[{idx}]")
518            elif obj_type == dict:
519                key_type, val_type = subtype
520                if self.__iterable_sample_pct__ < 100:
521                    sampled_keys = self.__get_sample_keys__(list(obj.keys()))
522                    if not self.__quick_check__(key_type, sampled_keys):
523                        for dk in sampled_keys:
524                            self.__check_type__(
525                                dk, key_type, f"{key}.key[{repr(dk)}]"
526                            )
527                    if not self.__quick_check__(
528                        val_type, [obj[dk] for dk in sampled_keys]
529                    ):
530                        for dk in sampled_keys:
531                            self.__check_type__(
532                                obj[dk], val_type, f"{key}[{repr(dk)}]"
533                            )
534                else:
535                    if not self.__quick_check__(key_type, obj.keys()):
536                        for key in obj.keys():
537                            self.__check_type__(
538                                key, key_type, f"{key}.key[{repr(key)}]"
539                            )
540                    if not self.__quick_check__(val_type, obj.values()):
541                        for key, value in obj.items():
542                            self.__check_type__(
543                                value, val_type, f"{key}[{repr(key)}]"
544                            )
545            elif obj_type == tuple:
546                expected_args, is_ellipsis = subtype
547                if is_ellipsis:
548                    if self.__iterable_sample_pct__ < 100:
549                        for idx in self.__get_sample_indices__(len(obj)):
550                            self.__check_type__(
551                                obj[idx], expected_args, f"{key}[{idx}]"
552                            )
553                    elif not self.__quick_check__(expected_args, obj):
554                        for idx, item in enumerate(obj):
555                            self.__check_type__(
556                                item, expected_args, f"{key}[{idx}]"
557                            )
558                else:
559                    if len(obj) != len(expected_args):
560                        self.__exception__(
561                            f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}"
562                        )
563                    for idx, (item, ex) in enumerate(zip(obj, expected_args)):
564                        self.__check_type__(item, ex, f"{key}[{idx}]")
565            elif obj_type == set:
566                if self.__iterable_sample_pct__ < 100:
567                    obj_list = list(obj)
568                    for idx in self.__get_sample_indices__(len(obj_list)):
569                        item = obj_list[idx]
570                        self.__check_type__(
571                            item, subtype, f"{key}[{repr(item)}]"
572                        )
573                elif not self.__quick_check__(subtype, obj):
574                    for item in obj:
575                        self.__check_type__(
576                            item, subtype, f"{key}[{repr(item)}]"
577                        )
578
579        # Validate constraints if any are present
580        if extra is not None:
581            constraints = extra.get("__constraints__", ())
582            for constraint in constraints:
583                constraint_validation_output = constraint.__validate__(key, obj)
584                if constraint_validation_output is not True:
585                    self.__exception__(
586                        f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}"
587                    )
588
589    def __repr__(self):
590        return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"
FunctionMethodEnforcer( __fn__, __strict__=False, __clean_traceback__=True, __iterable_sample_pct__=100)
def __init__(
    self,
    __fn__,
    __strict__=False,
    __clean_traceback__=True,
    __iterable_sample_pct__=100,
):
    """
    Wrap the function or method `__fn__` in a FunctionMethodEnforcer.

    Requires:

        - `__fn__`:
            - What: The function to enforce
            - Type: function | method
            - Note: Any other object (including a class) is rejected with an exception.

    Optional:

        - `__strict__`:
            - What: If True, a failed type check raises an exception; if False, a
                warning is printed to the console instead.
            - Type: bool
            - Default: False
        - `__clean_traceback__`:
            - What: Whether to trim tracebacks when raising exceptions.
            - Type: bool
            - Default: True
        - `__iterable_sample_pct__`:
            - What: Percentage of iterable items to validate. 100 validates every
                item; lower values validate a sample.
            - Type: int | float
            - Default: 100
    """
    # Mirror the wrapped callable's metadata (__name__, __qualname__, __doc__,
    # __wrapped__, ...) onto this wrapper.
    update_wrapper(self, __fn__)

    # The wrapped callable plus the options that control enforcement.
    self.__fn__ = __fn__
    (
        self.__strict__,
        self.__clean_traceback__,
        self.__iterable_sample_pct__,
    ) = (__strict__, __clean_traceback__, __iterable_sample_pct__)

    # Per-instance state that starts empty and is filled in later.
    self.__outer_self__ = None
    self.__types_parsed__ = False
    self.__flat_subtypes__ = {}

    # Reject non-function/method targets, then record their default values.
    self.__check_method_function__()
    self.__get_defaults__()

Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function __fn__.

Requires:

- `__fn__`:
    - What: The function to enforce
    - Type: function | method (any other object, including a class, raises an exception)

Optional:

- `__strict__`:
    - What: A boolean to enable or disable exceptions. If True, exceptions will be raised
        when type checking fails. If False, exceptions will not be raised but instead a warning
        will be printed to the console.
    - Type: bool
    - Default: False
- `__clean_traceback__`:
    - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
    - Type: bool
    - Default: True
- `__iterable_sample_pct__`:
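    - What: The percentage of items to sample when validating iterables. If 100, all items
        are validated. If 0, only the first item (or first dict key) is validated. Otherwise,
        lists, tuples and sets always validate their first and last items (every item when
        there are 3 or fewer) plus a random sample of the rest, and dicts always validate their
        first key plus a random sample of the remaining keys, up to the specified percentage.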
    - Type: int | float
    - Default: 100
@Partial
def Enforcer( clsFnMethod, enabled=True, strict=True, clean_traceback=True, iterable_sample_pct=100):
@Partial
def Enforcer(
    clsFnMethod,
    enabled=True,
    strict=True,
    clean_traceback=True,
    iterable_sample_pct=100,
):
    """
    A wrapper to enforce types within a function or method given argument annotations.

    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible.

    If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:

    - Methods with `__call__`
    - Methods wrapped with `staticmethod` (if python >= 3.10)
    - Methods wrapped with `classmethod` (if python >= 3.10)

    The first time an object is passed in, `enabled` is stored on it as `__type_enforced_enabled__`.
    On later calls the stored value is used and the `enabled` argument is ignored.
    Objects that do not accept new attributes (e.g. immutable built-ins) are returned unchanged.

    Requires:

    - `clsFnMethod`:
        - What: The class, function or method that should have input types enforced
        - Type: function | method | class

    Optional:

    - `enabled`:
        - What: A boolean to enable or disable the enforcer
        - Type: bool
        - Default: True
    - `strict`:
        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
        - Type: bool
        - Default: True
        - Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
    - `clean_traceback`:
        - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
        - Note: If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
        - Type: bool
        - Default: True
    - `iterable_sample_pct`:
        - What: The percentage (0-100) of items to validate when checking typed iterables (list,
            dict, set, variable-length tuple). At 100 (default) every item is checked. Below 100,
            the first and last items are always checked; if the collection has more than 3 items,
            additional items are randomly sampled so that the total checked is at least 3.
        - Type: int | float
        - Default: 100

    Returns:

    - For an annotated function or method: a `FunctionMethodEnforcer`, re-wrapped in
      `staticmethod` / `classmethod` when the input was one.
    - The input object unchanged when enforcement is disabled for it, when its type hints
      resolve to an empty mapping, or when it does not accept new attributes.
    - For a class: the same class object. Each callable member (and each static/class method)
      is passed through `Enforcer` and set back on the class.

    Raises:

    - `Exception`: If `clsFnMethod` is neither a function, method, staticmethod or classmethod
      nor an object with a `__dict__`.

    Example Use:
    ```
    >>> import type_enforced
    >>> @type_enforced.Enforcer
    ... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
    ...     pass
    ...
    >>> my_fn(a=1, b=2, c=3)
    >>> my_fn(a=1, b='2', c=3)
    >>> my_fn(a='a', b=2, c=3)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 85, in __call__
        self.__check_type__(assigned_vars.get(key), value, key)
      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 107, in __check_type__
        self.__exception__(
      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 34, in __exception__
        raise Exception(f"({self.__fn__.__qualname__}): {message}")
    Exception: (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` instead.
    ```
    """
    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
        # Record the enabled flag on the object itself. Objects that reject new
        # attributes (e.g. immutable built-ins) are returned unwrapped.
        # `except Exception` (not a bare `except`) so KeyboardInterrupt /
        # SystemExit are never swallowed here.
        try:
            clsFnMethod.__type_enforced_enabled__ = enabled
        except Exception:
            return clsFnMethod
    if not clsFnMethod.__type_enforced_enabled__:
        return clsFnMethod
    if isinstance(
        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
    ):
        # Only apply the enforcer if type hints are present. Resolving hints can
        # fail (e.g. unresolved forward references); in that case fall through
        # and wrap anyway so checking can be attempted at call time.
        try:
            if not get_type_hints(clsFnMethod):
                return clsFnMethod
        except Exception:
            pass
        if isinstance(clsFnMethod, staticmethod):
            # Wrap the underlying function, then restore the descriptor type.
            return staticmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__,
                    __strict__=strict,
                    __clean_traceback__=clean_traceback,
                    __iterable_sample_pct__=iterable_sample_pct,
                )
            )
        elif isinstance(clsFnMethod, classmethod):
            return classmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__,
                    __strict__=strict,
                    __clean_traceback__=clean_traceback,
                    __iterable_sample_pct__=iterable_sample_pct,
                )
            )
        else:
            return FunctionMethodEnforcer(
                __fn__=clsFnMethod,
                __strict__=strict,
                __clean_traceback__=clean_traceback,
                __iterable_sample_pct__=iterable_sample_pct,
            )
    elif hasattr(clsFnMethod, "__dict__"):
        # Snapshot the namespace: setattr below writes back into this same
        # mapping, so iterate over a copy rather than the live view.
        for key, value in list(clsFnMethod.__dict__.items()):
            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
            if key == "__annotate__" or isinstance(
                value, FunctionMethodEnforcer
            ):
                continue
            # Anything exposing `__call__` (including nested classes) or a
            # static/class method is recursively passed through Enforcer.
            if hasattr(value, "__call__") or isinstance(
                value, (classmethod, staticmethod)
            ):
                setattr(
                    clsFnMethod,
                    key,
                    Enforcer(
                        value,
                        enabled=enabled,
                        strict=strict,
                        clean_traceback=clean_traceback,
                        iterable_sample_pct=iterable_sample_pct,
                    ),
                )
        return clsFnMethod
    else:
        raise Exception(
            "Enforcer can only be used on classes, methods, or functions."
        )

A wrapper to enforce types within a function or method given argument annotations.

Each wrapped item is converted into a special FunctionMethodEnforcer class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a FunctionMethodEnforcer class as no validation is possible.

If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:

  • Methods with __call__
  • Methods wrapped with staticmethod (if python >= 3.10)
  • Methods wrapped with classmethod (if python >= 3.10)

Requires:

  • clsFnMethod:
    • What: The class, function or method that should have input types enforced
    • Type: function | method | class

Optional:

  • enabled:
    • What: A boolean to enable or disable the enforcer
    • Type: bool
    • Default: True
  • strict:
    • What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
    • Type: bool
    • Default: True
    • Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
  • clean_traceback:
    • What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
    • Note: If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
    • Type: bool
    • Default: True
  • iterable_sample_pct:
    • What: The percentage (0-100) of items to validate when checking typed iterables (list, dict, set, variable-length tuple). At 100 (default) every item is checked. Below 100, the first and last items are always checked; if the collection has more than 3 items, additional items are randomly sampled so that the total checked is at least 3.
    • Type: int | float
    • Default: 100

Example Use:

>>> import type_enforced
>>> @type_enforced.Enforcer
... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
...     pass
...
>>> my_fn(a=1, b=2, c=3)
>>> my_fn(a=1, b='2', c=3)
>>> my_fn(a='a', b=2, c=3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 85, in __call__
    self.__check_type__(assigned_vars.get(key), value, key)
  File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 107, in __check_type__
    self.__exception__(
  File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 34, in __exception__
    raise Exception(f"({self.__fn__.__qualname__}): {message}")
Exception: (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` instead.