type_enforced.enforcer
from types import (
    FunctionType,
    MethodType,
    GeneratorType,
    BuiltinFunctionType,
    BuiltinMethodType,
    UnionType,
)
from typing import Type, Union, Sized, Literal, Callable, get_type_hints, Any
from functools import update_wrapper, wraps
from type_enforced.utils import (
    Partial,
    GenericConstraint,
    DeepMerge,
    iterable_types,
)
import sys, traceback
from pathlib import Path


class FunctionMethodEnforcer:
    def __init__(self, __fn__, __strict__=False, __clean_traceback__=True):
        """
        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.

        Requires:

        - `__fn__`:
            - What: The function to enforce
            - Type: function | method
            - Note: Any other object (including a class) raises an `Exception`; use `Enforcer` to wrap classes.

        Optional:

        - `__strict__`:
            - What: A boolean to enable or disable exceptions. If True, exceptions will be raised
                when type checking fails. If False, exceptions will not be raised but instead a warning
                will be printed to the console.
            - Type: bool
            - Default: False
        - `__clean_traceback__`:
            - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
            - Type: bool
            - Default: True
        """
        update_wrapper(self, __fn__)
        self.__fn__ = __fn__
        self.__strict__ = __strict__
        self.__clean_traceback__ = __clean_traceback__
        # Kept for backward compatibility only: bound-method calls pass the
        # instance explicitly through `__get__` instead of storing it here.
        self.__outer_self__ = None
        # Validate that the passed function or method is a method or function
        self.__check_method_function__()
        # Get input defaults for the function or method
        self.__get_defaults__()

    def __get_defaults__(self):
        """
        Get the default values of the passed function or method and store them in `self.__fn_defaults__`.

        Positional defaults (`__defaults__`) are paired with the trailing positional
        parameter names; keyword-only defaults (`__kwdefaults__`) are merged as-is.
        """
        self.__fn_defaults__ = {}
        if self.__fn__.__defaults__ is not None:
            # Defaults belong to the last len(__defaults__) positional parameters
            default_varnames = list(self.__fn__.__code__.co_varnames)[
                : self.__fn__.__code__.co_argcount
            ][-len(self.__fn__.__defaults__) :]
            # Update the output dictionary with the default values
            self.__fn_defaults__.update(
                dict(zip(default_varnames, self.__fn__.__defaults__))
            )
        if self.__fn__.__kwdefaults__ is not None:
            # Update the output dictionary with the keyword default values
            self.__fn_defaults__.update(self.__fn__.__kwdefaults__)

    def __get_checkable_types__(self):
        """
        Create two instance attributes (only once, on the first call):

        - `self.__checkable_types__`:
            - What: A dictionary mapping each annotated parameter name to its checkable type
            - Type: dict

        - `self.__return_type__`:
            - What: The checkable return type of the function or method, or None if there is no return annotation
            - Type: dict | None

        Type hints are resolved here, at first call, rather than at decoration time.
        """
        if not hasattr(self, "__checkable_types__"):
            self.__checkable_types__ = {
                key: self.__get_checkable_type__(value)
                for key, value in get_type_hints(self.__fn__).items()
            }
            self.__return_type__ = self.__checkable_types__.pop("return", None)

    def __get_checkable_type__(self, annotation):
        """
        Parses a type annotation and returns a nested dict structure
        representing the checkable type(s) for validation.

        Shapes produced:

        - `{cls: None}` for plain classes (and `{Type[X]: None}` for `Type[X]`)
        - `{list: spec}` / `{set: spec}` for typed lists / sets
        - `{dict: (key_spec, value_spec)}` for typed dicts
        - `{tuple: (spec, True)}` for `tuple[X, ...]`, `{tuple: ((spec, ...), False)}` otherwise
        - an `"__extra__"` entry holding `"__literal__"` values or `"__constraints__"`

        Unions are merged with `DeepMerge`. Unsupported hints raise a `TypeError`
        via `__exception__(..., raise_exception=True)`.
        """

        if annotation is None:
            return {type(None): None}

        # Handle `int | str` syntax (Python 3.10+) and Unions
        if (
            isinstance(annotation, UnionType)
            or getattr(annotation, "__origin__", None) == Union
        ):
            combined_types = {}
            for sub_type in annotation.__args__:
                combined_types = DeepMerge(
                    combined_types, self.__get_checkable_type__(sub_type)
                )
            return combined_types

        # Handle typing.Literal
        if getattr(annotation, "__origin__", None) == Literal:
            return {"__extra__": {"__literal__": list(annotation.__args__)}}

        # Handle generic collections
        origin = getattr(annotation, "__origin__", None)
        args = getattr(annotation, "__args__", ())

        if origin == list:
            if len(args) != 1:
                self.__exception__(
                    f"List must have a single type argument, got: {args}",
                    raise_exception=True,
                )
            return {list: self.__get_checkable_type__(args[0])}

        if origin == dict:
            if len(args) != 2:
                self.__exception__(
                    f"Dict must have two type arguments, got: {args}",
                    raise_exception=True,
                )
            key_type = self.__get_checkable_type__(args[0])
            value_type = self.__get_checkable_type__(args[1])
            return {dict: (key_type, value_type)}

        if origin == tuple:
            if len(args) > 2 or len(args) == 1:
                if Ellipsis in args:
                    self.__exception__(
                        "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.",
                        raise_exception=True,
                    )
            if len(args) == 2:
                if args[0] is Ellipsis:
                    self.__exception__(
                        "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.",
                        raise_exception=True,
                    )
                if args[1] is Ellipsis:
                    return {tuple: (self.__get_checkable_type__(args[0]), True)}
            return {
                tuple: (
                    tuple(self.__get_checkable_type__(arg) for arg in args),
                    False,
                )
            }

        if origin == set:
            if len(args) != 1:
                self.__exception__(
                    f"Set must have a single type argument, got: {args}",
                    raise_exception=True,
                )
            return {set: self.__get_checkable_type__(args[0])}

        # Handle Sized types
        if annotation == Sized:
            return {
                list: None,
                tuple: None,
                dict: None,
                set: None,
                str: None,
                bytes: None,
                bytearray: None,
                memoryview: None,
                range: None,
            }

        # Handle Callable types
        if annotation == Callable:
            return {
                staticmethod: None,
                classmethod: None,
                FunctionType: None,
                BuiltinFunctionType: None,
                MethodType: None,
                BuiltinMethodType: None,
                GeneratorType: None,
            }

        if annotation == Any:
            return {
                object: None,
            }

        # Handle Constraints
        if isinstance(annotation, GenericConstraint):
            return {"__extra__": {"__constraints__": [annotation]}}

        # Handle standard types
        if isinstance(annotation, type):
            return {annotation: None}

        # Handle typing.Type (for uninitialized classes)
        if origin is type and len(args) == 1:
            return {annotation: None}

        self.__exception__(
            f"Unsupported type hint: {annotation}", raise_exception=True
        )

    def __exception__(self, message, raise_exception=False):
        """
        Usage:

        - Report a type-enforcement problem: raise a `TypeError` when strict mode is on
          (or `raise_exception` is True), otherwise print a warning to the console.

        Requires:

        - `message`:
            - Type: str
            - What: The message to warn users with

        Optional:

        - `raise_exception`:
            - Type: bool
            - What: Forces an exception to be raised regardless of the `self.__strict__` setting.
            - Default: False
        """
        if self.__strict__ or raise_exception:
            msg = f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}"
            error = TypeError(msg)
            if self.__clean_traceback__:
                # Count the stack frames that live outside this package so the
                # printed traceback can be limited to the caller's own frames.
                package_path = Path(__file__).parent.resolve()
                frame = sys._getframe()
                relevant_tb_count = 0
                while frame is not None:
                    frame_file = Path(frame.f_code.co_filename).resolve()
                    try:
                        frame_file.relative_to(package_path)
                    except ValueError:
                        relevant_tb_count += 1
                    frame = frame.f_back
                original_excepthook = sys.excepthook

                def excepthook(exc_type, exc_value, exc_tb):
                    # One-shot hook: restore the previous hook first, and trim only
                    # the traceback of the error raised below. Any other uncaught
                    # exception (e.g. after the caller caught this one) is passed
                    # to the previous hook untouched.
                    sys.excepthook = original_excepthook
                    if exc_value is error:
                        traceback.print_exception(
                            exc_type, exc_value, exc_tb, limit=relevant_tb_count
                        )
                    else:
                        original_excepthook(exc_type, exc_value, exc_tb)

                sys.excepthook = excepthook
            raise error
        else:
            print(
                f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}"
            )

    def __get__(self, obj, objtype=None):
        """
        Descriptor hook so the enforcer behaves like a method when stored on a class.

        Returns a wrapper that prepends `obj` (the object the attribute was looked up
        on) to the positional arguments before delegating to `__call__`. `obj` is
        captured per lookup, so bound methods taken from different instances stay
        independent. When looked up on the class itself (`obj is None`), arguments
        are passed through unchanged.
        """

        @wraps(self.__fn__)
        def __get_fn__(*args, **kwargs):
            if obj is None:
                return self.__call__(*args, **kwargs)
            return self.__call__(obj, *args, **kwargs)

        return __get_fn__

    def __check_method_function__(self):
        """
        Validate that `self.__fn__` is a method or function
        """
        if not isinstance(self.__fn__, (MethodType, FunctionType)):
            raise Exception(
                f"A non function/method was passed to Enforcer. See the stack trace above for more information."
            )

    def __call__(self, *args, **kwargs):
        """
        Validate the passed inputs, call the wrapped function or method, validate the
        returned value (when a return annotation exists) and return it.

        For bound-method calls, `__get__` has already placed the bound object first in `args`.
        """
        # Get a dictionary of all annotations as checkable types
        # Note: This is only done once at first call to avoid redundant calculations
        self.__get_checkable_types__()
        # Create a comprehensive dictionary of assigned variables (order matters:
        # explicit positional args override defaults, keyword args override both)
        assigned_vars = {
            **self.__fn_defaults__,
            **dict(zip(self.__fn__.__code__.co_varnames[: len(args)], args)),
            **kwargs,
        }
        # Validate all listed annotations vs the assigned_vars dictionary
        for key, value in self.__checkable_types__.items():
            self.__check_type__(assigned_vars.get(key), value, key)
        # Execute the function callable
        return_value = self.__fn__(*args, **kwargs)
        # If a return type was passed, validate the returned object
        if self.__return_type__ is not None:
            self.__check_type__(return_value, self.__return_type__, "return")
        return return_value

    def __quick_check__(self, subtype, obj):
        """
        Fast path for collection validation.

        Returns True when no entry of `subtype` carries a nested spec (all values are
        None) and the exact type of every item in `obj` is a key of `subtype`.
        Returns False otherwise; the caller then runs the full per-item validation,
        which also produces indexed/keyed error messages.
        """
        if all(spec is None for spec in subtype.values()):
            return {type(item) for item in obj}.issubset(set(subtype.keys()))
        return False

    def __type_subclass_match__(self, cls, expected):
        """
        Return True if the class object `cls` is a subclass of `X` for some `Type[X]`
        key in `expected`; keys that are not `Type[X]` aliases are ignored.
        """
        for accepted in expected:
            if getattr(accepted, "__origin__", None) is not type:
                continue
            args = getattr(accepted, "__args__", ())
            if len(args) != 1 or not isinstance(args[0], type):
                continue
            try:
                if issubclass(cls, args[0]):
                    return True
            except TypeError:
                # Some targets (e.g. parameterized generics) reject issubclass
                continue
        return False

    def __check_type__(self, obj, expected, key):
        """
        Validate `obj` against the checkable-type structure `expected` (as produced
        by `__get_checkable_type__`) and report mismatches through `__exception__`.

        - `obj`: The value to validate.
        - `expected`: A checkable-type dict, optionally with an `"__extra__"` entry.
        - `key`: The variable name or item path (e.g. `a[0]`) used in messages.
        """
        # Special case for None
        if obj is None and type(None) in expected:
            return
        extra = expected.get("__extra__", {})
        expected = {k: v for k, v in expected.items() if k != "__extra__"}

        # Only real classes may be passed to isinstance; `Type[X]` aliases would
        # make isinstance raise, so they are matched separately.
        plain_classes = tuple(
            k
            for k in expected
            if isinstance(k, type) and getattr(k, "__origin__", None) is None
        )

        is_class_object = isinstance(obj, type)
        if is_class_object:
            # An uninitialized class is passed: accept it when `Type[obj]` is listed,
            # when it is an instance of a plain class key (e.g. `type`, or `object`
            # from `Any`), or when it subclasses the target of a `Type[X]` key.
            obj_type = Type[obj]
            is_present = (
                obj_type in expected
                or isinstance(obj, plain_classes)
                or self.__type_subclass_match__(obj, expected)
            )
        else:
            obj_type = type(obj)
            is_present = isinstance(obj, plain_classes)

        if not is_present:
            # Allow for literals to be used to bypass type checks if present
            literal = extra.get("__literal__", ())
            if literal:
                if obj not in literal:
                    self.__exception__(
                        f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead."
                    )
            # Raise an exception if the type is not in the expected types
            else:
                self.__exception__(
                    f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead."
                )
        # Class objects are never collection instances, so only instances are
        # validated element-wise.
        elif not is_class_object and obj_type in iterable_types:
            subtype = expected.get(obj_type, None)
            if subtype is None:
                # Bare collection annotation: nothing to recurse into
                pass
            elif obj_type == list:
                if not self.__quick_check__(subtype, obj):
                    for idx, item in enumerate(obj):
                        self.__check_type__(item, subtype, f"{key}[{idx}]")
            elif obj_type == dict:
                key_type, val_type = subtype
                # dict_key/dict_value must not rebind `key`: it is the variable
                # path used in messages and in the constraint checks below.
                if not self.__quick_check__(key_type, obj.keys()):
                    for dict_key in obj.keys():
                        self.__check_type__(
                            dict_key, key_type, f"{key}.key[{repr(dict_key)}]"
                        )
                if not self.__quick_check__(val_type, obj.values()):
                    for dict_key, dict_value in obj.items():
                        self.__check_type__(
                            dict_value, val_type, f"{key}[{repr(dict_key)}]"
                        )
            elif obj_type == tuple:
                expected_args, is_ellipsis = subtype
                if is_ellipsis:
                    if not self.__quick_check__(expected_args, obj):
                        for idx, item in enumerate(obj):
                            self.__check_type__(
                                item, expected_args, f"{key}[{idx}]"
                            )
                else:
                    if len(obj) != len(expected_args):
                        self.__exception__(
                            f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}"
                        )
                    for idx, (item, ex) in enumerate(zip(obj, expected_args)):
                        self.__check_type__(item, ex, f"{key}[{idx}]")
            elif obj_type == set:
                if not self.__quick_check__(subtype, obj):
                    for item in obj:
                        self.__check_type__(
                            item, subtype, f"{key}[{repr(item)}]"
                        )

        # Validate constraints if any are present
        constraints = extra.get("__constraints__", [])
        for constraint in constraints:
            constraint_validation_output = constraint.__validate__(key, obj)
            if constraint_validation_output is not True:
                self.__exception__(
                    f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}"
                )

    def __repr__(self):
        return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"


@Partial
def Enforcer(clsFnMethod, enabled=True, strict=True, clean_traceback=True):
    """
    A wrapper to enforce types within a function or method given argument annotations.

    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible.

    If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:

    - Methods with `__call__`
    - Methods wrapped with `staticmethod` (if python >= 3.10)
    - Methods wrapped with `classmethod` (if python >= 3.10)

    Requires:

    - `clsFnMethod`:
        - What: The class, function or method that should have input types enforced
        - Type: function | method | class

    Optional:

    - `enabled`:
        - What: A boolean to enable or disable the enforcer
        - Type: bool
        - Default: True
    - `strict`:
        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
        - Type: bool
        - Default: True
        - Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
    - `clean_traceback`:
        - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
            - If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
        - Type: bool
        - Default: True

    Returns:

    - For functions/methods: a `FunctionMethodEnforcer` (re-wrapped in `staticmethod`/`classmethod` when the input was one).
    - For classes: the same class, with its callable attributes wrapped in place.
    - The input unchanged when it cannot be tagged with `__type_enforced_enabled__`, is disabled, or has no type hints.

    Raises:

    - `Exception` when `clsFnMethod` is neither a function/method/staticmethod/classmethod nor an object with a `__dict__`.


    Example Use:
    ```
    >>> import type_enforced
    >>> @type_enforced.Enforcer
    ... def my_fn(a: int, b: int | str = 2, c: int = 3) -> None:
    ...     pass
    ...
    >>> my_fn(a=1, b=2, c=3)
    >>> my_fn(a=1, b='2', c=3)
    >>> my_fn(a='a', b=2, c=3)
    Traceback (most recent call last):
      ...
    TypeError: TypeEnforced Exception (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.
    ```
    """
    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
        # Tag the object with its enabled flag; objects that reject new
        # attributes are returned unwrapped.
        try:
            clsFnMethod.__type_enforced_enabled__ = enabled
        except Exception:
            return clsFnMethod
    if not clsFnMethod.__type_enforced_enabled__:
        return clsFnMethod
    if isinstance(
        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
    ):
        # Only apply the enforcer if type hints are present. If the hints cannot
        # be resolved yet (e.g. forward refs), wrap anyway: FunctionMethodEnforcer
        # resolves them on its first call.
        try:
            if get_type_hints(clsFnMethod) == {}:
                return clsFnMethod
        except Exception:
            pass
        if isinstance(clsFnMethod, staticmethod):
            return staticmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__,
                    __strict__=strict,
                    __clean_traceback__=clean_traceback,
                )
            )
        elif isinstance(clsFnMethod, classmethod):
            return classmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__,
                    __strict__=strict,
                    __clean_traceback__=clean_traceback,
                )
            )
        else:
            return FunctionMethodEnforcer(
                __fn__=clsFnMethod,
                __strict__=strict,
                __clean_traceback__=clean_traceback,
            )
    elif hasattr(clsFnMethod, "__dict__"):
        # Snapshot the items because attributes are reassigned while iterating.
        for key, value in list(clsFnMethod.__dict__.items()):
            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
            if key == "__annotate__" or isinstance(
                value, FunctionMethodEnforcer
            ):
                continue
            if hasattr(value, "__call__") or isinstance(
                value, (classmethod, staticmethod)
            ):
                setattr(
                    clsFnMethod,
                    key,
                    Enforcer(
                        value,
                        enabled=enabled,
                        strict=strict,
                        clean_traceback=clean_traceback,
                    ),
                )
        return clsFnMethod
    else:
        raise Exception(
            "Enforcer can only be used on classes, methods, or functions."
        )
class
FunctionMethodEnforcer:
22class FunctionMethodEnforcer: 23 def __init__(self, __fn__, __strict__=False, __clean_traceback__=True): 24 """ 25 Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`. 26 27 Requires: 28 29 - `__fn__`: 30 - What: The function to enforce 31 - Type: function | method | class 32 33 Optional: 34 35 - `__strict__`: 36 - What: A boolean to enable or disable exceptions. If True, exceptions will be raised 37 when type checking fails. If False, exceptions will not be raised but instead a warning 38 will be printed to the console. 39 - Type: bool 40 - Default: False 41 - `__clean_traceback__`: 42 - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions. 43 - Type: bool 44 - Default: True 45 """ 46 update_wrapper(self, __fn__) 47 self.__fn__ = __fn__ 48 self.__strict__ = __strict__ 49 self.__clean_traceback__ = __clean_traceback__ 50 self.__outer_self__ = None 51 # Validate that the passed function or method is a method or function 52 self.__check_method_function__() 53 # Get input defaults for the function or method 54 self.__get_defaults__() 55 56 def __get_defaults__(self): 57 """ 58 Get the default values of the passed function or method and store them in `self.__fn_defaults__`. 
59 """ 60 self.__fn_defaults__ = {} 61 if self.__fn__.__defaults__ is not None: 62 # Get the names of all provided default values for args 63 default_varnames = list(self.__fn__.__code__.co_varnames)[ 64 : self.__fn__.__code__.co_argcount 65 ][-len(self.__fn__.__defaults__) :] 66 # Update the output dictionary with the default values 67 self.__fn_defaults__.update( 68 dict(zip(default_varnames, self.__fn__.__defaults__)) 69 ) 70 if self.__fn__.__kwdefaults__ is not None: 71 # Update the output dictionary with the keyword default values 72 self.__fn_defaults__.update(self.__fn__.__kwdefaults__) 73 74 def __get_checkable_types__(self): 75 """ 76 Creates two class attributes: 77 78 - `self.__checkable_types__`: 79 - What: A dictionary of all annotations as checkable types 80 - Type: dict 81 82 - `self.__return_type__`: 83 - What: The return type of the function or method 84 - Type: dict | None 85 """ 86 if not hasattr(self, "__checkable_types__"): 87 self.__checkable_types__ = { 88 key: self.__get_checkable_type__(value) 89 for key, value in get_type_hints(self.__fn__).items() 90 } 91 self.__return_type__ = self.__checkable_types__.pop("return", None) 92 93 def __get_checkable_type__(self, annotation): 94 """ 95 Parses a type annotation and returns a nested dict structure 96 representing the checkable type(s) for validation. 
97 """ 98 99 if annotation is None: 100 return {type(None): None} 101 102 # Handle `int | str` syntax (Python 3.10+) and Unions 103 if ( 104 isinstance(annotation, UnionType) 105 or getattr(annotation, "__origin__", None) == Union 106 ): 107 combined_types = {} 108 for sub_type in annotation.__args__: 109 combined_types = DeepMerge( 110 combined_types, self.__get_checkable_type__(sub_type) 111 ) 112 return combined_types 113 114 # Handle typing.Literal 115 if getattr(annotation, "__origin__", None) == Literal: 116 return {"__extra__": {"__literal__": list(annotation.__args__)}} 117 118 # Handle generic collections 119 origin = getattr(annotation, "__origin__", None) 120 args = getattr(annotation, "__args__", ()) 121 122 if origin == list: 123 if len(args) != 1: 124 self.__exception__( 125 f"List must have a single type argument, got: {args}", 126 raise_exception=True, 127 ) 128 return {list: self.__get_checkable_type__(args[0])} 129 130 if origin == dict: 131 if len(args) != 2: 132 self.__exception__( 133 f"Dict must have two type arguments, got: {args}", 134 raise_exception=True, 135 ) 136 key_type = self.__get_checkable_type__(args[0]) 137 value_type = self.__get_checkable_type__(args[1]) 138 return {dict: (key_type, value_type)} 139 140 if origin == tuple: 141 if len(args) > 2 or len(args) == 1: 142 if Ellipsis in args: 143 self.__exception__( 144 "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.", 145 raise_exception=True, 146 ) 147 if len(args) == 2: 148 if args[0] is Ellipsis: 149 self.__exception__( 150 "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.", 151 raise_exception=True, 152 ) 153 if args[1] is Ellipsis: 154 return {tuple: (self.__get_checkable_type__(args[0]), True)} 155 return { 156 tuple: ( 157 tuple(self.__get_checkable_type__(arg) for arg in args), 158 False, 159 ) 160 } 161 162 if origin == set: 163 if len(args) != 1: 164 self.__exception__( 165 f"Set must 
have a single type argument, got: {args}", 166 raise_exception=True, 167 ) 168 return {set: self.__get_checkable_type__(args[0])} 169 170 # Handle Sized types 171 if annotation == Sized: 172 return { 173 list: None, 174 tuple: None, 175 dict: None, 176 set: None, 177 str: None, 178 bytes: None, 179 bytearray: None, 180 memoryview: None, 181 range: None, 182 } 183 184 # Handle Callable types 185 if annotation == Callable: 186 return { 187 staticmethod: None, 188 classmethod: None, 189 FunctionType: None, 190 BuiltinFunctionType: None, 191 MethodType: None, 192 BuiltinMethodType: None, 193 GeneratorType: None, 194 } 195 196 if annotation == Any: 197 return { 198 object: None, 199 } 200 201 # Handle Constraints 202 if isinstance(annotation, GenericConstraint): 203 return {"__extra__": {"__constraints__": [annotation]}} 204 205 # Handle standard types 206 if isinstance(annotation, type): 207 return {annotation: None} 208 209 # Hanldle typing.Type (for uninitialized classes) 210 if origin is type and len(args) == 1: 211 return {annotation: None} 212 213 self.__exception__( 214 f"Unsupported type hint: {annotation}", raise_exception=True 215 ) 216 217 def __exception__(self, message, raise_exception=False): 218 """ 219 Usage: 220 221 - Creates a class based exception message 222 223 Requires: 224 225 - `message`: 226 - Type: str 227 - What: The message to warn users with 228 229 Optional: 230 231 - `raise_exception`: 232 - Type: bool 233 - What: Forces an exception to be raised regardless of the `self.__strict__` setting. 
234 - Default: False 235 """ 236 if self.__strict__ or raise_exception: 237 msg = f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}" 238 if self.__clean_traceback__: 239 package_path = Path(__file__).parent.resolve() 240 frame = sys._getframe() 241 relevant_tb_count = 0 242 while frame is not None: 243 frame_file = Path(frame.f_code.co_filename).resolve() 244 try: 245 frame_file.relative_to(package_path) 246 except ValueError: 247 relevant_tb_count += 1 248 frame = frame.f_back 249 original_excepthook = sys.excepthook 250 251 def excepthook(type, value, tb): 252 traceback.print_exception( 253 type, value, tb, limit=relevant_tb_count 254 ) 255 sys.excepthook = original_excepthook 256 257 sys.excepthook = excepthook 258 raise TypeError(msg) 259 else: 260 print( 261 f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}" 262 ) 263 264 def __get__(self, obj, objtype): 265 """ 266 Overwrite standard __get__ method to return __call__ instead for wrapped class methods. 267 268 Also stores the calling (__get__) `obj` to be passed as an initial argument for `__call__` such that methods can pass `self` correctly. 269 """ 270 271 @wraps(self.__fn__) 272 def __get_fn__(*args, **kwargs): 273 return self.__call__(*args, **kwargs) 274 275 self.__outer_self__ = obj 276 return __get_fn__ 277 278 def __check_method_function__(self): 279 """ 280 Validate that `self.__fn__` is a method or function 281 """ 282 if not isinstance(self.__fn__, (MethodType, FunctionType)): 283 raise Exception( 284 f"A non function/method was passed to Enforcer. See the stack trace above for more information." 285 ) 286 287 def __call__(self, *args, **kwargs): 288 """ 289 This method is used to validate the passed inputs and return the output of the wrapped function or method. 
290 """ 291 # Special code to pass self as an initial argument 292 # for validation purposes in methods 293 # See: self.__get__ 294 if self.__outer_self__ is not None: 295 args = (self.__outer_self__, *args) 296 # Get a dictionary of all annotations as checkable types 297 # Note: This is only done once at first call to avoid redundant calculations 298 self.__get_checkable_types__() 299 # Create a compreshensive dictionary of assigned variables (order matters) 300 assigned_vars = { 301 **self.__fn_defaults__, 302 **dict(zip(self.__fn__.__code__.co_varnames[: len(args)], args)), 303 **kwargs, 304 } 305 # Validate all listed annotations vs the assigned_vars dictionary 306 for key, value in self.__checkable_types__.items(): 307 self.__check_type__(assigned_vars.get(key), value, key) 308 # Execute the function callable 309 return_value = self.__fn__(*args, **kwargs) 310 # If a return type was passed, validate the returned object 311 if self.__return_type__ is not None: 312 self.__check_type__(return_value, self.__return_type__, "return") 313 return return_value 314 315 def __quick_check__(self, subtype, obj): 316 if all([v == None for v in subtype.values()]): 317 # If the subtype does not contain iterables with typing, we can validate the items directly. 318 types = set(subtype.keys()) 319 values = set([type(v) for v in obj]) 320 if values.issubset(types): 321 # We can return True to bypass the full validation 322 return True 323 # Otherwise, validation did not pass and a full validation is required to raise an indexed/keyed type mismatch error 324 return False 325 326 def __check_type__(self, obj, expected, key): 327 """ 328 Raises an exception the type of a passed `obj` (parameter) is not in the list of supplied `acceptable_types` for the argument. 
329 """ 330 # Special case for None 331 if obj is None and type(None) in expected: 332 return 333 extra = expected.get("__extra__", {}) 334 expected = {k: v for k, v in expected.items() if k != "__extra__"} 335 336 if isinstance(obj, type): 337 # An uninitialized class is passed, we need to check if the type is in the expected types using Type[obj] 338 obj_type = Type[obj] 339 is_present = obj_type in expected 340 else: 341 obj_type = type(obj) 342 is_present = isinstance(obj, tuple(expected.keys())) 343 344 if not is_present: 345 # Allow for literals to be used to bypass type checks if present 346 literal = extra.get("__literal__", ()) 347 if literal: 348 if obj not in literal: 349 self.__exception__( 350 f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead." 351 ) 352 # Raise an exception if the type is not in the expected types 353 else: 354 self.__exception__( 355 f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead." 356 ) 357 # If the object_type is in the expected types, we can proceed with validation 358 elif obj_type in iterable_types: 359 subtype = expected.get(obj_type, None) 360 if subtype is None: 361 pass 362 # Recursive validation 363 elif obj_type == list: 364 # If the subtype does not contain iterables with typing, we can validate the items directly. 
365 if not self.__quick_check__(subtype, obj): 366 for idx, item in enumerate(obj): 367 self.__check_type__(item, subtype, f"{key}[{idx}]") 368 elif obj_type == dict: 369 key_type, val_type = subtype 370 if not self.__quick_check__(key_type, obj.keys()): 371 for key in obj.keys(): 372 self.__check_type__( 373 key, key_type, f"{key}.key[{repr(key)}]" 374 ) 375 if not self.__quick_check__(val_type, obj.values()): 376 for key, value in obj.items(): 377 self.__check_type__( 378 value, val_type, f"{key}[{repr(key)}]" 379 ) 380 elif obj_type == tuple: 381 expected_args, is_ellipsis = subtype 382 if is_ellipsis: 383 if not self.__quick_check__(expected_args, obj): 384 for idx, item in enumerate(obj): 385 self.__check_type__( 386 item, expected_args, f"{key}[{idx}]" 387 ) 388 else: 389 if len(obj) != len(expected_args): 390 self.__exception__( 391 f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}" 392 ) 393 for idx, (item, ex) in enumerate(zip(obj, expected_args)): 394 self.__check_type__(item, ex, f"{key}[{idx}]") 395 elif obj_type == set: 396 if not self.__quick_check__(subtype, obj): 397 for item in obj: 398 self.__check_type__( 399 item, subtype, f"{key}[{repr(item)}]" 400 ) 401 402 # Validate constraints if any are present 403 constraints = extra.get("__constraints__", []) 404 for constraint in constraints: 405 constraint_validation_output = constraint.__validate__(key, obj) 406 if constraint_validation_output is not True: 407 self.__exception__( 408 f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}" 409 ) 410 411 def __repr__(self): 412 return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"
FunctionMethodEnforcer(__fn__, __strict__=False, __clean_traceback__=True)
def __init__(self, __fn__, __strict__=False, __clean_traceback__=True):
    """
    Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.

    Requires:

    - `__fn__`:
        - What: The function to enforce
        - Type: function | method
        - Note: Any other object (including a class) raises an `Exception` from
          `__check_method_function__`; use `Enforcer` to wrap classes.

    Optional:

    - `__strict__`:
        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised
            when type checking fails. If False, exceptions will not be raised but instead a warning
            will be printed to the console.
        - Type: bool
        - Default: False
    - `__clean_traceback__`:
        - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
        - Type: bool
        - Default: True
    """
    # Copy __name__, __qualname__, __doc__, __wrapped__, etc. from the wrapped callable
    update_wrapper(self, __fn__)
    self.__fn__ = __fn__
    self.__strict__ = __strict__
    self.__clean_traceback__ = __clean_traceback__
    # Object bound via `__get__` (if any); prepended to positional args by `__call__`
    self.__outer_self__ = None
    # Validate that the passed function or method is a method or function
    self.__check_method_function__()
    # Get input defaults for the function or method
    self.__get_defaults__()
Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function __fn__.
Requires:
- `__fn__`:
- What: The function to enforce
- Type: function | method | class
Optional:
- `__strict__`:
- What: A boolean to enable or disable exceptions. If True, exceptions will be raised
when type checking fails. If False, exceptions will not be raised but instead a warning
will be printed to the console.
- Type: bool
- Default: False
- `__clean_traceback__`:
- What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
- Type: bool
- Default: True
@Partial
def Enforcer(clsFnMethod, enabled=True, strict=True, clean_traceback=True):
@Partial
def Enforcer(clsFnMethod, enabled=True, strict=True, clean_traceback=True):
    """
    A wrapper to enforce types within a function or method given argument annotations.

    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible.

    If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:

    - Methods with `__call__`
    - Methods wrapped with `staticmethod` (if python >= 3.10)
    - Methods wrapped with `classmethod` (if python >= 3.10)

    Requires:

    - `clsFnMethod`:
        - What: The class, function or method that should have input types enforced
        - Type: function | method | class

    Optional:

    - `enabled`:
        - What: A boolean to enable or disable the enforcer
        - Type: bool
        - Default: True
    - `strict`:
        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
        - Type: bool
        - Default: True
        - Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
    - `clean_traceback`:
        - What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
            - If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
        - Type: bool
        - Default: True

    Returns:

    - A `FunctionMethodEnforcer` (re-wrapped in `staticmethod` / `classmethod` when the input was one) for annotated functions and methods.
    - The same class object, with its members wrapped in place, for classes.
    - `clsFnMethod` unchanged when enforcement is disabled, when no type hints are found, or when the object rejects new attributes.

    Raises:

    - `Exception`: If `clsFnMethod` is not a function, method, staticmethod or classmethod and has no `__dict__`.


    Example Use:
    ```
    >>> import type_enforced
    >>> @type_enforced.Enforcer
    ... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
    ...     pass
    ...
    >>> my_fn(a=1, b=2, c=3)
    >>> my_fn(a=1, b='2', c=3)
    >>> my_fn(a='a', b=2, c=3)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 85, in __call__
        self.__check_type__(assigned_vars.get(key), value, key)
      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 107, in __check_type__
        self.__exception__(
      File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 34, in __exception__
        raise Exception(f"({self.__fn__.__qualname__}): {message}")
    Exception: (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` instead.
    ```
    """
    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
        # Record the enabled flag on the object itself. Some objects are
        # immutable (e.g. builtins) and reject new attributes; those are
        # returned untouched. `Exception` (not a bare except) so that
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            clsFnMethod.__type_enforced_enabled__ = enabled
        except Exception:
            return clsFnMethod
    if not clsFnMethod.__type_enforced_enabled__:
        return clsFnMethod
    if isinstance(
        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
    ):
        # Only apply the enforcer if type hints are present. get_type_hints
        # can fail (e.g. on unresolved forward references); in that case
        # wrap anyway rather than silently skip enforcement.
        try:
            if get_type_hints(clsFnMethod) == {}:
                return clsFnMethod
        except Exception:
            pass

        def _wrap(fn):
            # Build an enforcer that carries this call's strict/traceback settings.
            return FunctionMethodEnforcer(
                __fn__=fn,
                __strict__=strict,
                __clean_traceback__=clean_traceback,
            )

        # Descriptors are unwrapped, enforced, then re-wrapped in the same
        # descriptor kind so binding behaviour is preserved.
        if isinstance(clsFnMethod, staticmethod):
            return staticmethod(_wrap(clsFnMethod.__func__))
        if isinstance(clsFnMethod, classmethod):
            return classmethod(_wrap(clsFnMethod.__func__))
        return _wrap(clsFnMethod)
    elif hasattr(clsFnMethod, "__dict__"):
        # Iterate over a snapshot: setattr below writes back into this same mapping.
        for key, value in list(clsFnMethod.__dict__.items()):
            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
            if key == "__annotate__" or isinstance(
                value, FunctionMethodEnforcer
            ):
                continue
            if hasattr(value, "__call__") or isinstance(
                value, (classmethod, staticmethod)
            ):
                setattr(
                    clsFnMethod,
                    key,
                    Enforcer(
                        value,
                        enabled=enabled,
                        strict=strict,
                        clean_traceback=clean_traceback,
                    ),
                )
        return clsFnMethod
    else:
        raise Exception(
            "Enforcer can only be used on classes, methods, or functions."
        )
A wrapper to enforce types within a function or method given argument annotations.
Each wrapped item is converted into a special FunctionMethodEnforcer class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a FunctionMethodEnforcer class as no validation is possible.
If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:
- Methods with `__call__`
- Methods wrapped with `staticmethod` (if python >= 3.10)
- Methods wrapped with `classmethod` (if python >= 3.10)
Requires:
- `clsFnMethod`:
- What: The class, function or method that should have input types enforced
- Type: function | method | class
Optional:
- `enabled`:
- What: A boolean to enable or disable the enforcer
- Type: bool
- Default: True
- `strict`:
- What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
- Type: bool
- Default: True
- Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
- `clean_traceback`:
- What: A boolean to enable or disable cleaning of tracebacks when raising exceptions.
- If True, modifies the excepthook temporarily such that only the relevant stack (not in the type_enforced package) is shown.
- Type: bool
- Default: True
Example Use:
>>> import type_enforced
>>> @type_enforced.Enforcer
... def my_fn(a: int , b: [int, str] =2, c: int =3) -> None:
... pass
...
>>> my_fn(a=1, b=2, c=3)
>>> my_fn(a=1, b='2', c=3)
>>> my_fn(a='a', b=2, c=3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 85, in __call__
self.__check_type__(assigned_vars.get(key), value, key)
File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 107, in __check_type__
self.__exception__(
File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 34, in __exception__
raise Exception(f"({self.__fn__.__qualname__}): {message}")
Exception: (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` instead.