type_enforced.enforcer
from types import (
    FunctionType,
    MethodType,
    GeneratorType,
    BuiltinFunctionType,
    BuiltinMethodType,
    UnionType,
)
from typing import Type, Union, Sized, Literal, Callable, get_type_hints, Any
from functools import update_wrapper
from type_enforced.utils import (
    Partial,
    GenericConstraint,
    iterable_types,
    merge_type_dicts,
)
import inspect
import sys, traceback, random
from pathlib import Path

_NoneType = type(None)
# Frames from files under this directory are hidden from "clean" tracebacks
# (see `FunctionMethodEnforcer.__exception__`).
_package_path = Path(__file__).parent.resolve()


class FunctionMethodEnforcer:
    # Wraps a single function or method and validates its arguments (and its
    # return value, when annotated) against its type hints on every call.
    #
    # NOTE: this class intentionally has no docstring. "__doc__" is listed in
    # __slots__ (it is filled from the wrapped function by `update_wrapper`),
    # and a class-level docstring would conflict with that slot.
    __slots__ = (
        "__fn__",
        "__strict__",
        "__clean_traceback__",
        "__iterable_sample_pct__",
        "__outer_self__",
        "__fn_defaults__",
        "__fn_varnames__",
        "__fn_argcount__",
        "__kw_param_names__",
        "__varargs_name__",
        "__varkw_name__",
        "__types_parsed__",
        "__checkable_types__",
        "__return_type__",
        "__simple_types__",
        "__complex_types__",
        "__simple_return_type__",
        "__param_indices__",
        "__flat_subtypes__",
        "__wrapped__",
        "__name__",
        "__qualname__",
        "__doc__",
        "__dict__",
    )

    def __init__(
        self,
        __fn__,
        __strict__=False,
        __clean_traceback__=True,
        __iterable_sample_pct__=100,
    ):
        """
        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.

        Requires:

        - `__fn__`:
            - What: The function to enforce
            - Type: function | method

        Optional:

        - `__strict__`:
            - What: A boolean to enable or disable exceptions. If True, exceptions will be raised
                when type checking fails. If False, exceptions will not be raised but instead a warning
                will be printed to the console.
            - Type: bool
            - Default: False
        - `__clean_traceback__`:
            - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
            - Type: bool
            - Default: True
        - `__iterable_sample_pct__`:
            - What: The percentage of items to sample when validating typed iterables. If 100, all
                items are validated. If 0, only the first item (first key for dicts) is validated.
                Otherwise lists, sets and variable-length tuples always validate their first and
                last items plus a random sample of the rest, and dicts validate their first key plus
                a random sample of the remaining keys.
            - Type: int | float
            - Default: 100

        Raises:

        - `Exception` if `__fn__` is not a function or method.
        """
        update_wrapper(self, __fn__)
        self.__fn__ = __fn__
        self.__strict__ = __strict__
        self.__clean_traceback__ = __clean_traceback__
        self.__iterable_sample_pct__ = __iterable_sample_pct__
        # Kept for backwards compatibility only: `__get__` no longer stores the
        # bound instance here (it captures it in a closure instead).
        self.__outer_self__ = None
        self.__types_parsed__ = False
        self.__flat_subtypes__ = {}
        # Validate that the passed function or method is a method or function
        self.__check_method_function__()
        # Introspect the signature (defaults, positional count, *args/**kwargs names)
        self.__get_defaults__()

    def __get_defaults__(self):
        """
        Introspect the wrapped function's code object and cache what `__call__` needs.

        Sets:

        - `self.__fn_varnames__`: `co_varnames` (parameters first, then other locals)
        - `self.__fn_argcount__`: number of parameters that can be passed positionally
        - `self.__kw_param_names__`: parameter names that can be bound by keyword
        - `self.__varargs_name__` / `self.__varkw_name__`: names of the `*args` /
            `**kwargs` parameters, or None when absent
        - `self.__fn_defaults__`: mapping of parameter name -> default value
        """
        code = self.__fn__.__code__
        self.__fn_varnames__ = code.co_varnames
        self.__fn_argcount__ = code.co_argcount
        # co_varnames lists positional params, then keyword-only params, then
        # the *args name, then the **kwargs name, then plain locals.
        named_end = code.co_argcount + code.co_kwonlyargcount
        # Positional-only params cannot be bound by keyword: a keyword with the
        # same name is collected by **kwargs instead.
        self.__kw_param_names__ = frozenset(
            code.co_varnames[code.co_posonlyargcount : named_end]
        )
        self.__varargs_name__ = None
        self.__varkw_name__ = None
        star_index = named_end
        if code.co_flags & inspect.CO_VARARGS:
            self.__varargs_name__ = code.co_varnames[star_index]
            star_index += 1
        if code.co_flags & inspect.CO_VARKEYWORDS:
            self.__varkw_name__ = code.co_varnames[star_index]

        self.__fn_defaults__ = {}
        if self.__fn__.__defaults__ is not None:
            # Positional defaults belong to the *last* positional parameters.
            default_varnames = list(self.__fn_varnames__)[
                : self.__fn__.__code__.co_argcount
            ][-len(self.__fn__.__defaults__) :]
            self.__fn_defaults__.update(
                dict(zip(default_varnames, self.__fn__.__defaults__))
            )
        if self.__fn__.__kwdefaults__ is not None:
            # Keyword-only defaults are already a name -> value mapping.
            self.__fn_defaults__.update(self.__fn__.__kwdefaults__)

    def __get_sample_indices__(self, length):
        """
        Get the indices to sample for iterable validation.

        If iterable_sample_pct is 0, only the first item (index 0) is checked.
        Otherwise, always includes the first (0) and last (length-1) indices.
        If length > 3, samples additional middle indices so that
        max(3, int(length * pct / 100)) indices are checked in total.
        Only called when self.__iterable_sample_pct__ < 100.
        """
        if length == 0:
            return []
        if self.__iterable_sample_pct__ == 0:
            return [0]
        if length <= 3:
            return range(length)
        n = max(3, int(length * self.__iterable_sample_pct__ / 100))
        if n >= length:
            return range(length)
        middle_sample = random.sample(range(1, length - 1), n - 2)
        return sorted([0] + middle_sample + [length - 1])

    def __get_sample_keys__(self, keys):
        """
        Get a sampled list of dict keys for iterable validation.

        If iterable_sample_pct is 0, only the first key is returned.
        Otherwise, always includes the first key plus a random sample of
        the remaining keys, max(1, int(len(keys) * pct / 100)) keys in total.
        Only called when self.__iterable_sample_pct__ < 100.
        """
        if len(keys) == 0:
            return []
        if self.__iterable_sample_pct__ == 0 or len(keys) == 1:
            return [keys[0]]
        n = max(1, int(len(keys) * self.__iterable_sample_pct__ / 100))
        if n >= len(keys):
            return keys
        return [keys[0]] + random.sample(keys[1:], n - 1)

    @staticmethod
    def __is_simple_type__(expected):
        """
        Return True when the checkable type dict `expected` can be validated by a single
        `isinstance` call: it has no `__extra__` entry (literals / constraints), no nested
        subtypes (all values None), and every key is a real class.
        """
        return (
            "__extra__" not in expected
            and all(v is None for v in expected.values())
            and all(isinstance(k, type) for k in expected.keys())
        )

    def __get_checkable_types__(self):
        """
        Parse the wrapped function's type hints and cache the results. This happens lazily,
        on the first call, and only once.

        Sets:

        - `self.__checkable_types__`: parameter name -> checkable type dict
        - `self.__return_type__`: checkable type dict for the return value, or None
        - `self.__simple_types__` / `self.__complex_types__`: the non-variadic parameters,
            split by whether a single `isinstance` call is enough to validate them
        - `self.__simple_return_type__`: tuple of classes when the return type is simple,
            else None
        - `self.__param_indices__`: parameter name -> index in `co_varnames`
        """
        if self.__types_parsed__:
            return
        self.__checkable_types__ = {
            key: self.__get_checkable_type__(value)
            for key, value in get_type_hints(self.__fn__).items()
        }
        self.__return_type__ = self.__checkable_types__.pop("return", None)
        variadic_names = (self.__varargs_name__, self.__varkw_name__)
        self.__simple_types__ = {}
        self.__complex_types__ = {}
        for key, expected in self.__checkable_types__.items():
            if key in variadic_names:
                # A `*args` / `**kwargs` hint describes each extra argument;
                # those are validated element-wise in `__enforced_call__`.
                continue
            if self.__is_simple_type__(expected):
                self.__simple_types__[key] = tuple(expected.keys())
            else:
                self.__complex_types__[key] = expected
        if self.__return_type__ is not None and self.__is_simple_type__(
            self.__return_type__
        ):
            self.__simple_return_type__ = tuple(self.__return_type__.keys())
        else:
            self.__simple_return_type__ = None
        # Pre-compute each annotated parameter's index in co_varnames so the
        # call path can look positional arguments up directly.
        self.__param_indices__ = {
            name: i
            for i, name in enumerate(self.__fn_varnames__)
            if name in self.__checkable_types__
        }
        self.__types_parsed__ = True

    def __get_checkable_type__(self, annotation):
        """
        Parse a type annotation into the nested dict structure used for validation.

        Keys are classes (or `Type[X]` aliases for uninitialized classes) and values are the
        checkable type of the contained items (or None when items are not checked). The
        special `__extra__` key carries literals and constraints.

        Raises `TypeError` (always, regardless of strictness) for malformed or unsupported hints.
        """

        if annotation is None:
            return {_NoneType: None}

        # Handle `int | str` syntax (Python 3.10+) and Unions
        if (
            isinstance(annotation, UnionType)
            or getattr(annotation, "__origin__", None) == Union
        ):
            combined_types = {}
            for sub_type in annotation.__args__:
                merge_type_dicts(
                    combined_types,
                    self.__get_checkable_type__(sub_type),
                )
            return combined_types

        # Handle typing.Literal
        if getattr(annotation, "__origin__", None) == Literal:
            return {"__extra__": {"__literal__": list(annotation.__args__)}}

        # Handle generic collections
        origin = getattr(annotation, "__origin__", None)
        args = getattr(annotation, "__args__", ())

        if origin == list:
            if len(args) != 1:
                self.__exception__(
                    f"List must have a single type argument, got: {args}",
                    raise_exception=True,
                )
            return {list: self.__get_checkable_type__(args[0])}

        if origin == dict:
            if len(args) != 2:
                self.__exception__(
                    f"Dict must have two type arguments, got: {args}",
                    raise_exception=True,
                )
            key_type = self.__get_checkable_type__(args[0])
            value_type = self.__get_checkable_type__(args[1])
            return {dict: (key_type, value_type)}

        if origin == tuple:
            if len(args) > 2 or len(args) == 1:
                if Ellipsis in args:
                    self.__exception__(
                        "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.",
                        raise_exception=True,
                    )
            if len(args) == 2:
                if args[0] is Ellipsis:
                    self.__exception__(
                        "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.",
                        raise_exception=True,
                    )
                if args[1] is Ellipsis:
                    # Variable-length tuple: (item type, True)
                    return {tuple: (self.__get_checkable_type__(args[0]), True)}
            # Fixed-length tuple: (per-position item types, False)
            return {
                tuple: (
                    tuple(self.__get_checkable_type__(arg) for arg in args),
                    False,
                )
            }

        if origin == set:
            if len(args) != 1:
                self.__exception__(
                    f"Set must have a single type argument, got: {args}",
                    raise_exception=True,
                )
            return {set: self.__get_checkable_type__(args[0])}

        # Handle Sized types
        if annotation == Sized:
            return {
                list: None,
                tuple: None,
                dict: None,
                set: None,
                str: None,
                bytes: None,
                bytearray: None,
                memoryview: None,
                range: None,
            }

        # Handle Callable types
        if annotation == Callable:
            return {
                staticmethod: None,
                classmethod: None,
                FunctionType: None,
                BuiltinFunctionType: None,
                MethodType: None,
                BuiltinMethodType: None,
                GeneratorType: None,
            }

        if annotation == Any:
            return {
                object: None,
            }

        # Handle Constraints
        if isinstance(annotation, GenericConstraint):
            return {"__extra__": {"__constraints__": [annotation]}}

        # Handle standard types
        if isinstance(annotation, type):
            return {annotation: None}

        # Handle typing.Type (for uninitialized classes)
        if origin is type and len(args) == 1:
            return {annotation: None}

        self.__exception__(
            f"Unsupported type hint: {annotation}", raise_exception=True
        )

    def __exception__(self, message, raise_exception=False):
        """
        Usage:

        - Reports a type enforcement problem: raises a `TypeError` when `self.__strict__` is
            set (or `raise_exception` is True), otherwise prints a warning.
        - When `self.__clean_traceback__` is set, installs a one-shot `sys.excepthook` that
            limits the printed traceback to the frames outside this package, then restores
            the previous hook after it runs.

        Requires:

        - `message`:
            - Type: str
            - What: The message to warn users with

        Optional:

        - `raise_exception`:
            - Type: bool
            - What: Forces an exception to be raised regardless of the `self.__strict__` setting.
            - Default: False
        """
        if self.__strict__ or raise_exception:
            msg = f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}"
            if self.__clean_traceback__:
                package_path = _package_path
                frame = sys._getframe()
                # Count the stack frames that do not belong to this package.
                relevant_tb_count = 0
                while frame is not None:
                    frame_file = Path(frame.f_code.co_filename).resolve()
                    try:
                        frame_file.relative_to(package_path)
                    except ValueError:
                        relevant_tb_count += 1
                    frame = frame.f_back
                original_excepthook = sys.excepthook

                def excepthook(exc_type, exc_value, exc_tb):
                    traceback.print_exception(
                        exc_type, exc_value, exc_tb, limit=relevant_tb_count
                    )
                    sys.excepthook = original_excepthook

                sys.excepthook = excepthook
            raise TypeError(msg)
        else:
            print(
                f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}"
            )

    def __get__(self, obj, objtype=None):
        """
        Descriptor hook so an enforcer stored on a class behaves like a method.

        Returns a plain function that prepends `obj` (when not None) to the positional
        arguments before validating and calling the wrapped function. `obj` is captured in
        the closure rather than stored on this shared enforcer. That way, bound methods of
        different instances (or concurrent calls) never pick up each other's `self`.
        """

        def __get_fn__(*args, **kwargs):
            if obj is not None:
                args = (obj, *args)
            return self.__enforced_call__(args, kwargs)

        __get_fn__.__name__ = self.__fn__.__name__
        __get_fn__.__qualname__ = self.__fn__.__qualname__
        __get_fn__.__doc__ = self.__fn__.__doc__
        return __get_fn__

    def __check_method_function__(self):
        """
        Validate that `self.__fn__` is a method or function.

        Raises `Exception` otherwise.
        """
        if not isinstance(self.__fn__, (MethodType, FunctionType)):
            raise Exception(
                "A non function/method was passed to Enforcer. See the stack trace above for more information."
            )

    def __get_arg_value__(self, key, args, kwargs):
        """
        Return the value bound to parameter `key` for a call with `args` / `kwargs`.

        Only indices below `co_argcount` are treated as positional. Keyword-only
        parameters therefore never pick up values captured by `*args`. Falls back to
        the keyword argument, then to the declared default, then to None.
        """
        idx = self.__param_indices__.get(key)
        if idx is not None and idx < self.__fn_argcount__ and idx < len(args):
            return args[idx]
        if key in kwargs:
            return kwargs[key]
        return self.__fn_defaults__.get(key)

    def __call__(self, *args, **kwargs):
        """
        Validate the passed inputs, call the wrapped function or method and validate
        (then return) its output.
        """
        # `__get__` no longer sets `__outer_self__`; honour it only if it was
        # assigned explicitly, for backwards compatibility.
        if self.__outer_self__ is not None:
            args = (self.__outer_self__, *args)
        return self.__enforced_call__(args, kwargs)

    def __enforced_call__(self, args, kwargs):
        """
        Validate `args` / `kwargs`, call the wrapped function, validate and return its result.

        `args` must already include the bound instance/class for methods.
        """
        # Parse annotations once, on the first call.
        self.__get_checkable_types__()
        checkable = self.__checkable_types__
        argcount = self.__fn_argcount__
        param_indices = self.__param_indices__
        defaults = self.__fn_defaults__
        n_args = len(args)

        # Fast path: simple types need a single isinstance call.
        # (Same lookup as `__get_arg_value__`, inlined because this is the hot path.)
        for key, types_tuple in self.__simple_types__.items():
            idx = param_indices.get(key)
            if idx is not None and idx < argcount and idx < n_args:
                obj = args[idx]
            elif key in kwargs:
                obj = kwargs[key]
            else:
                obj = defaults.get(key)
            if not isinstance(obj, types_tuple):
                # Fall back to the full check for error reporting
                self.__check_type__(obj, checkable[key], key)

        # Full validation for complex types (nested, extras, Type[X])
        for key, expected in self.__complex_types__.items():
            self.__check_type__(
                self.__get_arg_value__(key, args, kwargs), expected, key
            )

        # `*args: T` applies T to every extra positional argument.
        varargs_name = self.__varargs_name__
        if varargs_name is not None and varargs_name in checkable:
            expected = checkable[varargs_name]
            for offset, item in enumerate(args[argcount:]):
                self.__check_type__(item, expected, f"{varargs_name}[{offset}]")

        # `**kwargs: T` applies T to every keyword not bound to a named parameter.
        varkw_name = self.__varkw_name__
        if varkw_name is not None and varkw_name in checkable:
            expected = checkable[varkw_name]
            kw_params = self.__kw_param_names__
            for name, item in kwargs.items():
                if name not in kw_params:
                    self.__check_type__(item, expected, f"{varkw_name}[{name!r}]")

        # Execute the function callable
        return_value = self.__fn__(*args, **kwargs)
        # If a return type was annotated, validate the returned object
        return_type = self.__return_type__
        if return_type is not None:
            simple_return = self.__simple_return_type__
            if simple_return is None or not isinstance(return_value, simple_return):
                self.__check_type__(return_value, return_type, "return")
        return return_value

    def __quick_check__(self, subtype, obj):
        """
        Return True when every item of `obj` has an exact type listed in `subtype`, so the
        caller can skip per-item recursion.

        The shortcut only applies to "flat" subtypes (all values None, i.e. nothing nested
        and no `__extra__`). For any other subtype it returns False and callers fall back to
        per-item checks. The flat/not-flat result is cached per `subtype` object, keyed by `id`.
        """
        subtype_id = id(subtype)
        if subtype_id not in self.__flat_subtypes__:
            # First call for this subtype: compute and cache
            if all(v is None for v in subtype.values()):
                self.__flat_subtypes__[subtype_id] = frozenset(subtype.keys())
            else:
                self.__flat_subtypes__[subtype_id] = None
        flat_keys = self.__flat_subtypes__[subtype_id]
        if flat_keys is not None:
            values = {type(v) for v in obj}
            if values.issubset(flat_keys):
                return True
        return False

    def __check_type__(self, obj, expected, key):
        """
        Validate `obj` against the checkable type dict `expected` and report any mismatch
        through `__exception__`. It raises when strict and prints a warning otherwise.

        `key` is the human-readable path of the value being checked (e.g. `arg`, `arg[3]`,
        `arg.key['a']` or `return`). It is used in messages and passed to constraints.
        """
        # Special case for None
        if obj is None and _NoneType in expected:
            return
        if "__extra__" in expected:
            extra = expected["__extra__"]
            expected = {k: v for k, v in expected.items() if k != "__extra__"}
        else:
            extra = None

        if isinstance(obj, type):
            # An uninitialized class was passed: it only matches `Type[obj]` entries.
            obj_type = Type[obj]
            is_present = obj_type in expected
        else:
            obj_type = type(obj)
            # Only real classes may go to isinstance: `Type[X]` aliases would
            # make it raise "Subscripted generics cannot be used ...".
            instance_classes = tuple(k for k in expected if isinstance(k, type))
            is_present = obj_type in expected or isinstance(obj, instance_classes)

        if not is_present:
            # Allow for literals to be used to bypass type checks if present
            literal = extra.get("__literal__", ()) if extra is not None else ()
            if literal:
                if obj not in literal:
                    self.__exception__(
                        f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead."
                    )
            # Raise an exception if the type is not in the expected types
            else:
                self.__exception__(
                    f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead."
                )
        # The object's type is accepted: validate the items of typed iterables.
        elif obj_type in iterable_types:
            subtype = expected.get(obj_type, None)
            if subtype is not None:
                self.__check_iterable_items__(obj, obj_type, subtype, key)

        # Validate constraints if any are present
        if extra is not None:
            for constraint in extra.get("__constraints__", ()):
                constraint_validation_output = constraint.__validate__(key, obj)
                if constraint_validation_output is not True:
                    self.__exception__(
                        f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}"
                    )

    def __check_iterable_items__(self, obj, obj_type, subtype, key):
        """
        Recursively validate the items of a typed list / dict / tuple / set `obj`
        (whose exact type is `obj_type`) against `subtype`.

        Honours `self.__iterable_sample_pct__`. Fixed-length tuples are always checked in full.
        """
        sampling = self.__iterable_sample_pct__ < 100
        if obj_type == list:
            if sampling:
                for idx in self.__get_sample_indices__(len(obj)):
                    self.__check_type__(obj[idx], subtype, f"{key}[{idx}]")
            # If the subtype is flat, all items may be validated in one pass.
            elif not self.__quick_check__(subtype, obj):
                for idx, item in enumerate(obj):
                    self.__check_type__(item, subtype, f"{key}[{idx}]")
        elif obj_type == dict:
            key_type, val_type = subtype
            if sampling:
                sampled_keys = self.__get_sample_keys__(list(obj.keys()))
                if not self.__quick_check__(key_type, sampled_keys):
                    for dict_key in sampled_keys:
                        self.__check_type__(
                            dict_key, key_type, f"{key}.key[{repr(dict_key)}]"
                        )
                if not self.__quick_check__(
                    val_type, [obj[dict_key] for dict_key in sampled_keys]
                ):
                    for dict_key in sampled_keys:
                        self.__check_type__(
                            obj[dict_key], val_type, f"{key}[{repr(dict_key)}]"
                        )
            else:
                # Loop variables are named `dict_key` so they do not shadow the
                # parameter path `key` used in messages and constraints.
                if not self.__quick_check__(key_type, obj.keys()):
                    for dict_key in obj.keys():
                        self.__check_type__(
                            dict_key, key_type, f"{key}.key[{repr(dict_key)}]"
                        )
                if not self.__quick_check__(val_type, obj.values()):
                    for dict_key, dict_value in obj.items():
                        self.__check_type__(
                            dict_value, val_type, f"{key}[{repr(dict_key)}]"
                        )
        elif obj_type == tuple:
            expected_args, is_ellipsis = subtype
            if is_ellipsis:
                if sampling:
                    for idx in self.__get_sample_indices__(len(obj)):
                        self.__check_type__(
                            obj[idx], expected_args, f"{key}[{idx}]"
                        )
                elif not self.__quick_check__(expected_args, obj):
                    for idx, item in enumerate(obj):
                        self.__check_type__(
                            item, expected_args, f"{key}[{idx}]"
                        )
            else:
                if len(obj) != len(expected_args):
                    self.__exception__(
                        f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}"
                    )
                for idx, (item, item_expected) in enumerate(zip(obj, expected_args)):
                    self.__check_type__(item, item_expected, f"{key}[{idx}]")
        elif obj_type == set:
            if sampling:
                obj_list = list(obj)
                for idx in self.__get_sample_indices__(len(obj_list)):
                    item = obj_list[idx]
                    self.__check_type__(item, subtype, f"{key}[{repr(item)}]")
            elif not self.__quick_check__(subtype, obj):
                for item in obj:
                    self.__check_type__(item, subtype, f"{key}[{repr(item)}]")

    def __repr__(self):
        return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"


@Partial
def Enforcer(
    clsFnMethod,
    enabled=True,
    strict=True,
    clean_traceback=True,
    iterable_sample_pct=100,
):
    """
    A wrapper to enforce types within a function or method given argument annotations.

    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters (and the return value, if annotated) for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible. Annotations on `*args` / `**kwargs` are applied to each extra positional / keyword argument.

    If wrapping a class, every attribute in the class `__dict__` that meets any of the following criteria will be wrapped individually:

    - Callable attributes (anything with `__call__`)
    - Methods wrapped with `staticmethod` (if python >= 3.10)
    - Methods wrapped with `classmethod` (if python >= 3.10)

    Attributes that already are `FunctionMethodEnforcer` objects, and `__annotate__`, are skipped.

    Requires:

    - `clsFnMethod`:
        - What: The class, function or method that should have input types enforced
        - Type: function | method | class

    Optional:

    - `enabled`:
        - What: A boolean to enable or disable the enforcer
        - Type: bool
        - Default: True
    - `strict`:
        - What: A boolean to enable or disable exceptions. If True, a `TypeError` is raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
        - Type: bool
        - Default: True
        - Note: Invalid or unsupported type hints always raise an exception, regardless of this setting.
    - `clean_traceback`:
        - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
            - If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
        - Type: bool
        - Default: True
    - `iterable_sample_pct`:
        - What: The percentage (0-100) of items to validate when checking typed iterables (list,
            dict, set, variable-length tuple). At 100 (default) every item is checked. At 0 only
            the first item (first key for dicts) is checked. Otherwise lists, sets and
            variable-length tuples always check their first and last items and, when they have
            more than 3 items, check max(3, int(len * pct / 100)) items in total; dicts check
            their first key plus randomly sampled keys, max(1, int(len * pct / 100)) in total.
            Fixed-length tuples are always checked in full.
        - Type: int | float
        - Default: 100


    Example Use:
    ```
    >>> import type_enforced
    >>> @type_enforced.Enforcer
    ... def my_fn(a: int, b: int | str = 2, c: int = 3) -> None:
    ...     pass
    ...
    >>> my_fn(a=1, b=2, c=3)
    >>> my_fn(a=1, b='2', c=3)
    >>> my_fn(a='a', b=2, c=3)
    Traceback (most recent call last):
      ...
    TypeError: TypeEnforced Exception (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.
    ```
    """
    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
        # Objects that reject new attributes (e.g. immutable or builtin
        # objects) cannot be enforced; return them untouched.
        try:
            clsFnMethod.__type_enforced_enabled__ = enabled
        except Exception:
            return clsFnMethod
    if not clsFnMethod.__type_enforced_enabled__:
        return clsFnMethod
    if isinstance(
        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
    ):
        # Only apply the enforcer if type hints are present.
        # Hints that cannot be resolved yet (e.g. forward refs) are tolerated
        # here; `FunctionMethodEnforcer` resolves them on its first call.
        try:
            if get_type_hints(clsFnMethod) == {}:
                return clsFnMethod
        except Exception:
            pass
        enforcer_kwargs = {
            "__strict__": strict,
            "__clean_traceback__": clean_traceback,
            "__iterable_sample_pct__": iterable_sample_pct,
        }
        if isinstance(clsFnMethod, staticmethod):
            return staticmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__, **enforcer_kwargs
                )
            )
        elif isinstance(clsFnMethod, classmethod):
            return classmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__, **enforcer_kwargs
                )
            )
        else:
            return FunctionMethodEnforcer(__fn__=clsFnMethod, **enforcer_kwargs)
    elif hasattr(clsFnMethod, "__dict__"):
        # Iterate over a snapshot because the loop rebinds class attributes.
        for key, value in list(clsFnMethod.__dict__.items()):
            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
            if key == "__annotate__" or isinstance(
                value, FunctionMethodEnforcer
            ):
                continue
            if hasattr(value, "__call__") or isinstance(
                value, (classmethod, staticmethod)
            ):
                setattr(
                    clsFnMethod,
                    key,
                    Enforcer(
                        value,
                        enabled=enabled,
                        strict=strict,
                        clean_traceback=clean_traceback,
                        iterable_sample_pct=iterable_sample_pct,
                    ),
                )
        return clsFnMethod
    else:
        raise Exception(
            "Enforcer can only be used on classes, methods, or functions."
        )
class
FunctionMethodEnforcer:
25class FunctionMethodEnforcer: 26 __slots__ = ( 27 "__fn__", 28 "__strict__", 29 "__clean_traceback__", 30 "__iterable_sample_pct__", 31 "__outer_self__", 32 "__fn_defaults__", 33 "__fn_varnames__", 34 "__types_parsed__", 35 "__checkable_types__", 36 "__return_type__", 37 "__simple_types__", 38 "__complex_types__", 39 "__simple_return_type__", 40 "__param_indices__", 41 "__flat_subtypes__", 42 "__wrapped__", 43 "__name__", 44 "__qualname__", 45 "__doc__", 46 "__dict__", 47 ) 48 49 def __init__( 50 self, 51 __fn__, 52 __strict__=False, 53 __clean_traceback__=True, 54 __iterable_sample_pct__=100, 55 ): 56 """ 57 Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`. 58 59 Requires: 60 61 - `__fn__`: 62 - What: The function to enforce 63 - Type: function | method | class 64 65 Optional: 66 67 - `__strict__`: 68 - What: A boolean to enable or disable exceptions. If True, exceptions will be raised 69 when type checking fails. If False, exceptions will not be raised but instead a warning 70 will be printed to the console. 71 - Type: bool 72 - Default: False 73 - `__clean_traceback__`: 74 - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions. 75 - Type: bool 76 - Default: True 77 - `__iterable_sample_pct__`: 78 - What: The percentage of items to sample when validating iterables. If 100, all items 79 are validated. If less than 100, the first and last items are always validated 80 plus a random sample of the remaining items up to the specified percentage. 
81 - Type: int | float 82 - Default: 100 83 """ 84 update_wrapper(self, __fn__) 85 self.__fn__ = __fn__ 86 self.__strict__ = __strict__ 87 self.__clean_traceback__ = __clean_traceback__ 88 self.__iterable_sample_pct__ = __iterable_sample_pct__ 89 self.__outer_self__ = None 90 self.__types_parsed__ = False 91 self.__flat_subtypes__ = {} 92 # Validate that the passed function or method is a method or function 93 self.__check_method_function__() 94 # Get input defaults for the function or method 95 self.__get_defaults__() 96 97 def __get_defaults__(self): 98 """ 99 Get the default values of the passed function or method and store them in `self.__fn_defaults__`. 100 Also caches the function's variable names for use in `__call__`. 101 """ 102 self.__fn_varnames__ = self.__fn__.__code__.co_varnames 103 self.__fn_defaults__ = {} 104 if self.__fn__.__defaults__ is not None: 105 # Get the names of all provided default values for args 106 default_varnames = list(self.__fn_varnames__)[ 107 : self.__fn__.__code__.co_argcount 108 ][-len(self.__fn__.__defaults__) :] 109 # Update the output dictionary with the default values 110 self.__fn_defaults__.update( 111 dict(zip(default_varnames, self.__fn__.__defaults__)) 112 ) 113 if self.__fn__.__kwdefaults__ is not None: 114 # Update the output dictionary with the keyword default values 115 self.__fn_defaults__.update(self.__fn__.__kwdefaults__) 116 117 def __get_sample_indices__(self, length): 118 """ 119 Get a sorted list of indices to sample for iterable validation. 120 121 If iterable_sample_pct is 0, only the first item (index 0) is checked. 122 Otherwise, always includes the first (0) and last (length-1) indices. 123 If length > 3, samples additional middle indices up to the 124 iterable_sample_pct percentage. 125 Only called when self.__iterable_sample_pct__ < 100. 
126 """ 127 if length == 0: 128 return [] 129 if self.__iterable_sample_pct__ == 0: 130 return [0] 131 if length <= 3: 132 return range(length) 133 n = max(3, int(length * self.__iterable_sample_pct__ / 100)) 134 if n >= length: 135 return range(length) 136 middle_sample = random.sample(range(1, length - 1), n - 2) 137 return sorted([0] + middle_sample + [length - 1]) 138 139 def __get_sample_keys__(self, keys): 140 """ 141 Get a sampled list of dict keys for iterable validation. 142 143 If iterable_sample_pct is 0, only the first key is returned. 144 Otherwise, always includes the first key plus a random sample of 145 the remaining keys up to the iterable_sample_pct percentage. 146 Only called when self.__iterable_sample_pct__ < 100. 147 """ 148 if len(keys) == 0: 149 return [] 150 if self.__iterable_sample_pct__ == 0 or len(keys) == 1: 151 return [keys[0]] 152 n = max(1, int(len(keys) * self.__iterable_sample_pct__ / 100)) 153 if n >= len(keys): 154 return keys 155 return [keys[0]] + random.sample(keys[1:], n - 1) 156 157 def __get_checkable_types__(self): 158 """ 159 Creates two class attributes: 160 161 - `self.__checkable_types__`: 162 - What: A dictionary of all annotations as checkable types 163 - Type: dict 164 165 - `self.__return_type__`: 166 - What: The return type of the function or method 167 - Type: dict | None 168 """ 169 if not self.__types_parsed__: 170 self.__checkable_types__ = { 171 key: self.__get_checkable_type__(value) 172 for key, value in get_type_hints(self.__fn__).items() 173 } 174 self.__return_type__ = self.__checkable_types__.pop("return", None) 175 # Classify params: simple types can use a single 176 # isinstance call, skipping __check_type__ entirely. 
177 self.__simple_types__ = {} 178 self.__complex_types__ = {} 179 for key, expected in self.__checkable_types__.items(): 180 if ( 181 "__extra__" not in expected 182 and all(v is None for v in expected.values()) 183 and all(isinstance(k, type) for k in expected.keys()) 184 ): 185 self.__simple_types__[key] = tuple(expected.keys()) 186 else: 187 self.__complex_types__[key] = expected 188 # Same classification for return type 189 if self.__return_type__ is not None and ( 190 "__extra__" not in self.__return_type__ 191 and all(v is None for v in self.__return_type__.values()) 192 and all( 193 isinstance(k, type) for k in self.__return_type__.keys() 194 ) 195 ): 196 self.__simple_return_type__ = tuple(self.__return_type__.keys()) 197 else: 198 self.__simple_return_type__ = None 199 # Pre-compute param index in co_varnames for 200 # direct arg lookup (skips assigned_vars dict). 201 self.__param_indices__ = { 202 name: i 203 for i, name in enumerate(self.__fn_varnames__) 204 if name in self.__checkable_types__ 205 } 206 self.__types_parsed__ = True 207 208 def __get_checkable_type__(self, annotation): 209 """ 210 Parses a type annotation and returns a nested dict structure 211 representing the checkable type(s) for validation. 
212 """ 213 214 if annotation is None: 215 return {_NoneType: None} 216 217 # Handle `int | str` syntax (Python 3.10+) and Unions 218 if ( 219 isinstance(annotation, UnionType) 220 or getattr(annotation, "__origin__", None) == Union 221 ): 222 combined_types = {} 223 for sub_type in annotation.__args__: 224 merge_type_dicts( 225 combined_types, 226 self.__get_checkable_type__(sub_type), 227 ) 228 return combined_types 229 230 # Handle typing.Literal 231 if getattr(annotation, "__origin__", None) == Literal: 232 return {"__extra__": {"__literal__": list(annotation.__args__)}} 233 234 # Handle generic collections 235 origin = getattr(annotation, "__origin__", None) 236 args = getattr(annotation, "__args__", ()) 237 238 if origin == list: 239 if len(args) != 1: 240 self.__exception__( 241 f"List must have a single type argument, got: {args}", 242 raise_exception=True, 243 ) 244 return {list: self.__get_checkable_type__(args[0])} 245 246 if origin == dict: 247 if len(args) != 2: 248 self.__exception__( 249 f"Dict must have two type arguments, got: {args}", 250 raise_exception=True, 251 ) 252 key_type = self.__get_checkable_type__(args[0]) 253 value_type = self.__get_checkable_type__(args[1]) 254 return {dict: (key_type, value_type)} 255 256 if origin == tuple: 257 if len(args) > 2 or len(args) == 1: 258 if Ellipsis in args: 259 self.__exception__( 260 "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.", 261 raise_exception=True, 262 ) 263 if len(args) == 2: 264 if args[0] is Ellipsis: 265 self.__exception__( 266 "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.", 267 raise_exception=True, 268 ) 269 if args[1] is Ellipsis: 270 return {tuple: (self.__get_checkable_type__(args[0]), True)} 271 return { 272 tuple: ( 273 tuple(self.__get_checkable_type__(arg) for arg in args), 274 False, 275 ) 276 } 277 278 if origin == set: 279 if len(args) != 1: 280 self.__exception__( 281 f"Set must have a 
single type argument, got: {args}", 282 raise_exception=True, 283 ) 284 return {set: self.__get_checkable_type__(args[0])} 285 286 # Handle Sized types 287 if annotation == Sized: 288 return { 289 list: None, 290 tuple: None, 291 dict: None, 292 set: None, 293 str: None, 294 bytes: None, 295 bytearray: None, 296 memoryview: None, 297 range: None, 298 } 299 300 # Handle Callable types 301 if annotation == Callable: 302 return { 303 staticmethod: None, 304 classmethod: None, 305 FunctionType: None, 306 BuiltinFunctionType: None, 307 MethodType: None, 308 BuiltinMethodType: None, 309 GeneratorType: None, 310 } 311 312 if annotation == Any: 313 return { 314 object: None, 315 } 316 317 # Handle Constraints 318 if isinstance(annotation, GenericConstraint): 319 return {"__extra__": {"__constraints__": [annotation]}} 320 321 # Handle standard types 322 if isinstance(annotation, type): 323 return {annotation: None} 324 325 # Hanldle typing.Type (for uninitialized classes) 326 if origin is type and len(args) == 1: 327 return {annotation: None} 328 329 self.__exception__( 330 f"Unsupported type hint: {annotation}", raise_exception=True 331 ) 332 333 def __exception__(self, message, raise_exception=False): 334 """ 335 Usage: 336 337 - Creates a class based exception message 338 339 Requires: 340 341 - `message`: 342 - Type: str 343 - What: The message to warn users with 344 345 Optional: 346 347 - `raise_exception`: 348 - Type: bool 349 - What: Forces an exception to be raised regardless of the `self.__strict__` setting. 
350 - Default: False 351 """ 352 if self.__strict__ or raise_exception: 353 msg = f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}" 354 if self.__clean_traceback__: 355 package_path = _package_path 356 frame = sys._getframe() 357 relevant_tb_count = 0 358 while frame is not None: 359 frame_file = Path(frame.f_code.co_filename).resolve() 360 try: 361 frame_file.relative_to(package_path) 362 except ValueError: 363 relevant_tb_count += 1 364 frame = frame.f_back 365 original_excepthook = sys.excepthook 366 367 def excepthook(type, value, tb): 368 traceback.print_exception( 369 type, value, tb, limit=relevant_tb_count 370 ) 371 sys.excepthook = original_excepthook 372 373 sys.excepthook = excepthook 374 raise TypeError(msg) 375 else: 376 print( 377 f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}" 378 ) 379 380 def __get__(self, obj, objtype): 381 """ 382 Overwrite standard __get__ method to return __call__ instead for wrapped class methods. 383 384 Also stores the calling (__get__) `obj` to be passed as an initial argument for `__call__` such that methods can pass `self` correctly. 385 """ 386 self.__outer_self__ = obj 387 388 def __get_fn__(*args, **kwargs): 389 return self.__call__(*args, **kwargs) 390 391 __get_fn__.__name__ = self.__fn__.__name__ 392 __get_fn__.__qualname__ = self.__fn__.__qualname__ 393 __get_fn__.__doc__ = self.__fn__.__doc__ 394 return __get_fn__ 395 396 def __check_method_function__(self): 397 """ 398 Validate that `self.__fn__` is a method or function 399 """ 400 if not isinstance(self.__fn__, (MethodType, FunctionType)): 401 raise Exception( 402 f"A non function/method was passed to Enforcer. See the stack trace above for more information." 403 ) 404 405 def __call__(self, *args, **kwargs): 406 """ 407 This method is used to validate the passed inputs and return the output of the wrapped function or method. 
408 """ 409 # Special code to pass self as an initial argument 410 # for validation purposes in methods 411 # See: self.__get__ 412 if self.__outer_self__ is not None: 413 args = (self.__outer_self__, *args) 414 # Get a dictionary of all annotations as checkable types 415 # Note: This is only done once at first call to avoid redundant calculations 416 self.__get_checkable_types__() 417 # Fast path: simple types use direct index lookup 418 for key, types_tuple in self.__simple_types__.items(): 419 idx = self.__param_indices__[key] 420 if idx < len(args): 421 obj = args[idx] 422 elif key in kwargs: 423 obj = kwargs[key] 424 else: 425 obj = self.__fn_defaults__.get(key) 426 if not isinstance(obj, types_tuple): 427 # Fall back to full check for error reporting 428 self.__check_type__(obj, self.__checkable_types__[key], key) 429 # Full validation for complex types (nested, extras, Type[X]) 430 if self.__complex_types__: 431 assigned_vars = { 432 **self.__fn_defaults__, 433 **dict(zip(self.__fn_varnames__[: len(args)], args)), 434 **kwargs, 435 } 436 for key, value in self.__complex_types__.items(): 437 self.__check_type__(assigned_vars.get(key), value, key) 438 # Execute the function callable 439 return_value = self.__fn__(*args, **kwargs) 440 # If a return type was passed, validate the returned object 441 if self.__return_type__ is not None: 442 if self.__simple_return_type__ is not None: 443 if not isinstance(return_value, self.__simple_return_type__): 444 self.__check_type__( 445 return_value, self.__return_type__, "return" 446 ) 447 else: 448 self.__check_type__( 449 return_value, self.__return_type__, "return" 450 ) 451 return return_value 452 453 def __quick_check__(self, subtype, obj): 454 subtype_id = id(subtype) 455 if subtype_id not in self.__flat_subtypes__: 456 # First call for this subtype: compute and cache 457 if all(v is None for v in subtype.values()): 458 self.__flat_subtypes__[subtype_id] = frozenset(subtype.keys()) 459 else: 460 
self.__flat_subtypes__[subtype_id] = None 461 flat_keys = self.__flat_subtypes__[subtype_id] 462 if flat_keys is not None: 463 values = {type(v) for v in obj} 464 if values.issubset(flat_keys): 465 return True 466 return False 467 468 def __check_type__(self, obj, expected, key): 469 """ 470 Raises an exception the type of a passed `obj` (parameter) is not in the list of supplied `acceptable_types` for the argument. 471 """ 472 # Special case for None 473 if obj is None and _NoneType in expected: 474 return 475 if "__extra__" in expected: 476 extra = expected["__extra__"] 477 expected = {k: v for k, v in expected.items() if k != "__extra__"} 478 else: 479 extra = None 480 481 if isinstance(obj, type): 482 # An uninitialized class is passed, we need to check if the type is in the expected types using Type[obj] 483 obj_type = Type[obj] 484 is_present = obj_type in expected 485 else: 486 obj_type = type(obj) 487 is_present = obj_type in expected or isinstance( 488 obj, tuple(expected.keys()) 489 ) 490 491 if not is_present: 492 # Allow for literals to be used to bypass type checks if present 493 literal = extra.get("__literal__", ()) if extra is not None else () 494 if literal: 495 if obj not in literal: 496 self.__exception__( 497 f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead." 498 ) 499 # Raise an exception if the type is not in the expected types 500 else: 501 self.__exception__( 502 f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead." 
503 ) 504 # If the object_type is in the expected types, we can proceed with validation 505 elif obj_type in iterable_types: 506 subtype = expected.get(obj_type, None) 507 if subtype is None: 508 pass 509 # Recursive validation 510 elif obj_type == list: 511 if self.__iterable_sample_pct__ < 100: 512 for idx in self.__get_sample_indices__(len(obj)): 513 self.__check_type__(obj[idx], subtype, f"{key}[{idx}]") 514 # If the subtype does not contain iterables with typing, we can validate the items directly. 515 elif not self.__quick_check__(subtype, obj): 516 for idx, item in enumerate(obj): 517 self.__check_type__(item, subtype, f"{key}[{idx}]") 518 elif obj_type == dict: 519 key_type, val_type = subtype 520 if self.__iterable_sample_pct__ < 100: 521 sampled_keys = self.__get_sample_keys__(list(obj.keys())) 522 if not self.__quick_check__(key_type, sampled_keys): 523 for dk in sampled_keys: 524 self.__check_type__( 525 dk, key_type, f"{key}.key[{repr(dk)}]" 526 ) 527 if not self.__quick_check__( 528 val_type, [obj[dk] for dk in sampled_keys] 529 ): 530 for dk in sampled_keys: 531 self.__check_type__( 532 obj[dk], val_type, f"{key}[{repr(dk)}]" 533 ) 534 else: 535 if not self.__quick_check__(key_type, obj.keys()): 536 for key in obj.keys(): 537 self.__check_type__( 538 key, key_type, f"{key}.key[{repr(key)}]" 539 ) 540 if not self.__quick_check__(val_type, obj.values()): 541 for key, value in obj.items(): 542 self.__check_type__( 543 value, val_type, f"{key}[{repr(key)}]" 544 ) 545 elif obj_type == tuple: 546 expected_args, is_ellipsis = subtype 547 if is_ellipsis: 548 if self.__iterable_sample_pct__ < 100: 549 for idx in self.__get_sample_indices__(len(obj)): 550 self.__check_type__( 551 obj[idx], expected_args, f"{key}[{idx}]" 552 ) 553 elif not self.__quick_check__(expected_args, obj): 554 for idx, item in enumerate(obj): 555 self.__check_type__( 556 item, expected_args, f"{key}[{idx}]" 557 ) 558 else: 559 if len(obj) != len(expected_args): 560 self.__exception__( 
561 f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}" 562 ) 563 for idx, (item, ex) in enumerate(zip(obj, expected_args)): 564 self.__check_type__(item, ex, f"{key}[{idx}]") 565 elif obj_type == set: 566 if self.__iterable_sample_pct__ < 100: 567 obj_list = list(obj) 568 for idx in self.__get_sample_indices__(len(obj_list)): 569 item = obj_list[idx] 570 self.__check_type__( 571 item, subtype, f"{key}[{repr(item)}]" 572 ) 573 elif not self.__quick_check__(subtype, obj): 574 for item in obj: 575 self.__check_type__( 576 item, subtype, f"{key}[{repr(item)}]" 577 ) 578 579 # Validate constraints if any are present 580 if extra is not None: 581 constraints = extra.get("__constraints__", ()) 582 for constraint in constraints: 583 constraint_validation_output = constraint.__validate__(key, obj) 584 if constraint_validation_output is not True: 585 self.__exception__( 586 f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}" 587 ) 588 589 def __repr__(self): 590 return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"
FunctionMethodEnforcer( __fn__, __strict__=False, __clean_traceback__=True, __iterable_sample_pct__=100)
49 def __init__( 50 self, 51 __fn__, 52 __strict__=False, 53 __clean_traceback__=True, 54 __iterable_sample_pct__=100, 55 ): 56 """ 57 Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`. 58 59 Requires: 60 61 - `__fn__`: 62 - What: The function to enforce 63 - Type: function | method (anything else raises an Exception) 64 65 Optional: 66 67 - `__strict__`: 68 - What: A boolean to enable or disable exceptions. If True, exceptions will be raised 69 when type checking fails. If False, exceptions will not be raised but instead a warning 70 will be printed to the console. 71 - Type: bool 72 - Default: False 73 - `__clean_traceback__`: 74 - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions. 75 - Type: bool 76 - Default: True 77 - `__iterable_sample_pct__`: 78 - What: The percentage of items to sample when validating iterables. If 100, all items 79 are validated. If less than 100, the first and last items are always validated 80 plus a random sample of the remaining items up to the specified percentage. 81 - Type: int | float 82 - Default: 100 83 """ 84 update_wrapper(self, __fn__) 85 self.__fn__ = __fn__ 86 self.__strict__ = __strict__ 87 self.__clean_traceback__ = __clean_traceback__ 88 self.__iterable_sample_pct__ = __iterable_sample_pct__ 89 self.__outer_self__ = None 90 self.__types_parsed__ = False 91 self.__flat_subtypes__ = {} 92 # Validate that the passed function or method is a method or function 93 self.__check_method_function__() 94 # Get input defaults for the function or method 95 self.__get_defaults__()
Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function __fn__.
Requires:
- `__fn__`:
- What: The function to enforce
- Type: function | method (anything else raises an Exception)
Optional:
- `__strict__`:
- What: A boolean to enable or disable exceptions. If True, exceptions will be raised
when type checking fails. If False, exceptions will not be raised but instead a warning
will be printed to the console.
- Type: bool
- Default: False
- `__clean_traceback__`:
- What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
- Type: bool
- Default: True
- `__iterable_sample_pct__`:
- What: The percentage of items to sample when validating iterables. If 100, all items
are validated. If less than 100, the first and last items are always validated
plus a random sample of the remaining items up to the specified percentage.
- Type: int | float
- Default: 100
@Partial
def
Enforcer( clsFnMethod, enabled=True, strict=True, clean_traceback=True, iterable_sample_pct=100):
593@Partial 594def Enforcer( 595 clsFnMethod, 596 enabled=True, 597 strict=True, 598 clean_traceback=True, 599 iterable_sample_pct=100, 600): 601 """ 602 A wrapper to enforce types within a function or method given argument annotations. 603 604 Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible. 605 606 If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually: 607 608 - Methods with `__call__` 609 - Methods wrapped with `staticmethod` (if python >= 3.10) 610 - Methods wrapped with `classmethod` (if python >= 3.10) 611 612 Requires: 613 614 - `clsFnMethod`: 615 - What: The class, function or method that should have input types enforced 616 - Type: function | method | class 617 618 Optional: 619 620 - `enabled`: 621 - What: A boolean to enable or disable the enforcer 622 - Type: bool 623 - Default: True 624 - `strict`: 625 - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console. 626 - Type: bool 627 - Default: True 628 - Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception. 629 - `clean_traceback`: 630 - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions. 631 - If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown. 632 - Type: bool 633 - Default: True 634 - `iterable_sample_pct`: 635 - What: The percentage (0-100) of items to validate when checking typed iterables (list, 636 dict, set, variable-length tuple).
At 100 (default) every item is checked. Below 100, 637 the first and last items are always checked; if the collection has more than 3 items, 638 additional items are randomly sampled so that the total checked is at least 3. 639 - Type: int | float 640 - Default: 100 641 642 643 Example Use: 644 ``` 645 >>> import type_enforced 646 >>> @type_enforced.Enforcer 647 ... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None: 648 ... pass 649 ... 650 >>> my_fn(a=1, b=2, c=3) 651 >>> my_fn(a=1, b='2', c=3) 652 >>> my_fn(a='a', b=2, c=3) 653 Traceback (most recent call last): 654 File "<stdin>", line 1, in <module> 655 File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 85, in __call__ 656 self.__check_type__(assigned_vars.get(key), value, key) 657 File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 107, in __check_type__ 658 self.__exception__( 659 File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 34, in __exception__ 660 raise Exception(f"({self.__fn__.__qualname__}): {message}") 661 Exception: (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` instead. 662 ``` 663 """ 664 if not hasattr(clsFnMethod, "__type_enforced_enabled__"): 665 # Special try except clause to handle cases when the object is immutable 666 try: 667 clsFnMethod.__type_enforced_enabled__ = enabled 668 except: 669 return clsFnMethod 670 if not clsFnMethod.__type_enforced_enabled__: 671 return clsFnMethod 672 if isinstance( 673 clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType) 674 ): 675 # Only apply the enforcer if type_hints are present 676 # Add try except clause to better handle forward refs.
677 try: 678 if get_type_hints(clsFnMethod) == {}: 679 return clsFnMethod 680 except: 681 pass 682 if isinstance(clsFnMethod, staticmethod): 683 return staticmethod( 684 FunctionMethodEnforcer( 685 __fn__=clsFnMethod.__func__, 686 __strict__=strict, 687 __clean_traceback__=clean_traceback, 688 __iterable_sample_pct__=iterable_sample_pct, 689 ) 690 ) 691 elif isinstance(clsFnMethod, classmethod): 692 return classmethod( 693 FunctionMethodEnforcer( 694 __fn__=clsFnMethod.__func__, 695 __strict__=strict, 696 __clean_traceback__=clean_traceback, 697 __iterable_sample_pct__=iterable_sample_pct, 698 ) 699 ) 700 else: 701 return FunctionMethodEnforcer( 702 __fn__=clsFnMethod, 703 __strict__=strict, 704 __clean_traceback__=clean_traceback, 705 __iterable_sample_pct__=iterable_sample_pct, 706 ) 707 elif hasattr(clsFnMethod, "__dict__"): 708 for key, value in clsFnMethod.__dict__.items(): 709 # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation 710 # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer 711 if key == "__annotate__" or isinstance( 712 value, FunctionMethodEnforcer 713 ): 714 continue 715 if hasattr(value, "__call__") or isinstance( 716 value, (classmethod, staticmethod) 717 ): 718 setattr( 719 clsFnMethod, 720 key, 721 Enforcer( 722 value, 723 enabled=enabled, 724 strict=strict, 725 clean_traceback=clean_traceback, 726 iterable_sample_pct=iterable_sample_pct, 727 ), 728 ) 729 return clsFnMethod 730 else: 731 raise Exception( 732 "Enforcer can only be used on classes, methods, or functions." 733 )
A wrapper to enforce types within a function or method given argument annotations.
Each wrapped item is converted into a special FunctionMethodEnforcer class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a FunctionMethodEnforcer class as no validation is possible.
If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:
- Methods with `__call__`
- Methods wrapped with `staticmethod` (if python >= 3.10)
- Methods wrapped with `classmethod` (if python >= 3.10)
Requires:
- `clsFnMethod`:
- What: The class, function or method that should have input types enforced
- Type: function | method | class
Optional:
- `enabled`:
- What: A boolean to enable or disable the enforcer
- Type: bool
- Default: True
- `strict`:
- What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
- Type: bool
- Default: True
- Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
- `clean_traceback`:
- What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
- If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
- Type: bool
- Default: True
- `iterable_sample_pct`:
- What: The percentage (0-100) of items to validate when checking typed iterables (list, dict, set, variable-length tuple). At 100 (default) every item is checked. Below 100, the first and last items are always checked; if the collection has more than 3 items, additional items are randomly sampled so that the total checked is at least 3.
- Type: int | float
- Default: 100
Example Use:
>>> import type_enforced
>>> @type_enforced.Enforcer
... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
... pass
...
>>> my_fn(a=1, b=2, c=3)
>>> my_fn(a=1, b='2', c=3)
>>> my_fn(a='a', b=2, c=3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 85, in __call__
self.__check_type__(assigned_vars.get(key), value, key)
File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 107, in __check_type__
self.__exception__(
File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 34, in __exception__
raise Exception(f"({self.__fn__.__qualname__}): {message}")
Exception: (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` instead.