type_enforced.enforcer
from types import (
    FunctionType,
    MethodType,
    GeneratorType,
    BuiltinFunctionType,
    BuiltinMethodType,
    UnionType,
)
from typing import Type, Union, Sized, Literal, Callable, get_type_hints, Any
from functools import update_wrapper, wraps
from type_enforced.utils import (
    Partial,
    GenericConstraint,
    DeepMerge,
    iterable_types,
)


class FunctionMethodEnforcer:
    """
    Callable wrapper that validates the arguments and the return value of a
    function or method against its type annotations on every call.
    """

    def __init__(self, __fn__, __strict__=False):
        """
        Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`.

        Requires:

        - `__fn__`:
            - What: The function to enforce
            - Type: function | method

        Optional:

        - `__strict__`:
            - What: If True, validation failures raise a `TypeError`. If False, they are printed as warnings.
            - Type: bool
            - Default: False
        """
        # Copy metadata first so the attributes assigned below always win over
        # anything merged in from `__fn__.__dict__`.
        update_wrapper(self, __fn__)
        self.__fn__ = __fn__
        self.__strict__ = __strict__
        # Kept for backward compatibility. `__get__` does not write to it;
        # `__call__` still prepends it when it has been set explicitly.
        self.__outer_self__ = None
        # Validate that the passed function or method is a method or function
        self.__check_method_function__()
        # Get input defaults for the function or method
        self.__get_defaults__()

    def __get_defaults__(self):
        """
        Get the default values of the passed function or method and store them in `self.__fn_defaults__`.
        """
        defaults = {}
        positional_defaults = self.__fn__.__defaults__
        if positional_defaults:
            code = self.__fn__.__code__
            positional_names = code.co_varnames[: code.co_argcount]
            # Defaults always belong to the trailing positional parameters
            defaults.update(
                zip(
                    positional_names[-len(positional_defaults) :],
                    positional_defaults,
                )
            )
        if self.__fn__.__kwdefaults__:
            defaults.update(self.__fn__.__kwdefaults__)
        self.__fn_defaults__ = defaults

    def __get_variadic_names__(self):
        """
        Returns a tuple `(var_positional, var_keyword)` holding the names of the
        `*args` and `**kwargs` style parameters of the wrapped function
        (`None` for each that is not declared).
        """
        code = self.__fn__.__code__
        # co_varnames layout: positional params, keyword-only params,
        # the *args name, the **kwargs name, then locals.
        idx = code.co_argcount + code.co_kwonlyargcount
        var_positional = None
        var_keyword = None
        if code.co_flags & 0x04:  # CO_VARARGS
            var_positional = code.co_varnames[idx]
            idx += 1
        if code.co_flags & 0x08:  # CO_VARKEYWORDS
            var_keyword = code.co_varnames[idx]
        return var_positional, var_keyword

    def __get_checkable_types__(self):
        """
        Creates two class attributes (only on the first call):

        - `self.__checkable_types__`:
            - What: A dictionary of all annotations as checkable types
            - Type: dict

        - `self.__return_type__`:
            - What: The return type of the function or method
            - Type: dict | None
        """
        if hasattr(self, "__checkable_types__"):
            return
        checkable = {
            name: self.__get_checkable_type__(hint)
            for name, hint in get_type_hints(self.__fn__).items()
        }
        self.__return_type__ = checkable.pop("return", None)
        self.__checkable_types__ = checkable

    def __get_checkable_type__(self, annotation):
        """
        Parses a type annotation and returns a nested dict structure
        representing the checkable type(s) for validation.

        Raises a `TypeError` (regardless of `self.__strict__`) for malformed
        or unsupported annotations.
        """
        if annotation is None:
            return {type(None): None}

        origin = getattr(annotation, "__origin__", None)
        args = getattr(annotation, "__args__", ())

        # Handle `int | str` syntax (Python 3.10+) and Unions
        if isinstance(annotation, UnionType) or origin == Union:
            combined_types = {}
            for sub_type in annotation.__args__:
                combined_types = DeepMerge(
                    combined_types, self.__get_checkable_type__(sub_type)
                )
            return combined_types

        # Handle typing.Literal
        if origin == Literal:
            return {"__extra__": {"__literal__": list(annotation.__args__)}}

        # Handle generic collections
        if origin == list:
            if len(args) != 1:
                self.__exception__(
                    f"List must have a single type argument, got: {args}",
                    raise_exception=True,
                )
            return {list: self.__get_checkable_type__(args[0])}

        if origin == dict:
            if len(args) != 2:
                self.__exception__(
                    f"Dict must have two type arguments, got: {args}",
                    raise_exception=True,
                )
            key_type = self.__get_checkable_type__(args[0])
            value_type = self.__get_checkable_type__(args[1])
            return {dict: (key_type, value_type)}

        if origin == tuple:
            if len(args) > 2 or len(args) == 1:
                if Ellipsis in args:
                    self.__exception__(
                        "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.",
                        raise_exception=True,
                    )
            if len(args) == 2:
                if args[0] is Ellipsis:
                    self.__exception__(
                        "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.",
                        raise_exception=True,
                    )
                if args[1] is Ellipsis:
                    # Homogeneous, variable length tuple
                    return {tuple: (self.__get_checkable_type__(args[0]), True)}
            # Fixed length tuple: one checkable type per position
            return {
                tuple: (
                    tuple(self.__get_checkable_type__(arg) for arg in args),
                    False,
                )
            }

        if origin == set:
            if len(args) != 1:
                self.__exception__(
                    f"Set must have a single type argument, got: {args}",
                    raise_exception=True,
                )
            return {set: self.__get_checkable_type__(args[0])}

        # Handle Sized types
        if annotation == Sized:
            return {
                list: None,
                tuple: None,
                dict: None,
                set: None,
                str: None,
                bytes: None,
                bytearray: None,
                memoryview: None,
                range: None,
            }

        # Handle Callable types
        if annotation == Callable:
            return {
                staticmethod: None,
                classmethod: None,
                FunctionType: None,
                BuiltinFunctionType: None,
                MethodType: None,
                BuiltinMethodType: None,
                GeneratorType: None,
            }

        if annotation == Any:
            return {
                object: None,
            }

        # Handle Constraints
        if isinstance(annotation, GenericConstraint):
            return {"__extra__": {"__constraints__": [annotation]}}

        # Handle standard types
        if isinstance(annotation, type):
            return {annotation: None}

        # Handle typing.Type (for uninitialized classes)
        if origin is type and len(args) == 1:
            return {annotation: None}

        self.__exception__(
            f"Unsupported type hint: {annotation}", raise_exception=True
        )

    def __exception__(self, message, raise_exception=False):
        """
        Usage:

        - Creates a class based exception message

        Requires:

        - `message`:
            - Type: str
            - What: The message to warn users with

        Optional:

        - `raise_exception`:
            - Type: bool
            - What: Forces an exception to be raised regardless of the `self.__strict__` setting.
            - Default: False
        """
        if self.__strict__ or raise_exception:
            raise TypeError(
                f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}"
            )
        print(f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}")

    def __get__(self, obj, objtype=None):
        """
        Overwrite standard __get__ method to return a validating callable for wrapped class methods.

        The accessing `obj` is bound inside the returned callable (and passed as
        the first argument to `__call__`) rather than stored on this shared
        wrapper, so bound methods taken from different instances stay independent.
        """

        @wraps(self.__fn__)
        def __get_fn__(*args, **kwargs):
            if obj is not None:
                args = (obj, *args)
            return self.__call__(*args, **kwargs)

        return __get_fn__

    def __check_method_function__(self):
        """
        Validate that `self.__fn__` is a method or function
        """
        if not isinstance(self.__fn__, (MethodType, FunctionType)):
            raise Exception(
                "A non function/method was passed to Enforcer. See the stack trace above for more information."
            )

    def __call__(self, *args, **kwargs):
        """
        This method is used to validate the passed inputs and return the output of the wrapped function or method.

        Annotations on `*args` / `**kwargs` style parameters are applied to every
        extra positional value / extra keyword value respectively.
        """
        # Honor an explicitly assigned outer self (never set by `__get__`)
        if self.__outer_self__ is not None:
            args = (self.__outer_self__, *args)
        # Get a dictionary of all annotations as checkable types
        # Note: This is only done once at first call to avoid redundant calculations
        self.__get_checkable_types__()
        code = self.__fn__.__code__
        # Only real positional parameters may receive positional values. Extra
        # values belong to `*args` and must not spill onto keyword-only names.
        positional_names = code.co_varnames[: code.co_argcount]
        var_positional, var_keyword = self.__get_variadic_names__()
        # Create a comprehensive dictionary of assigned variables (order matters)
        assigned_vars = {
            **self.__fn_defaults__,
            **dict(zip(positional_names, args)),
            **kwargs,
        }
        # Validate all listed annotations vs the assigned_vars dictionary
        for key, expected in self.__checkable_types__.items():
            if key == var_positional:
                for idx, item in enumerate(args[code.co_argcount :]):
                    self.__check_type__(item, expected, f"{key}[{idx}]")
            elif key == var_keyword:
                # Names that can be bound by keyword to a declared parameter
                keyword_names = code.co_varnames[
                    code.co_posonlyargcount : code.co_argcount
                    + code.co_kwonlyargcount
                ]
                for name, item in kwargs.items():
                    if name not in keyword_names:
                        self.__check_type__(
                            item, expected, f"{key}[{repr(name)}]"
                        )
            else:
                self.__check_type__(assigned_vars.get(key), expected, key)
        # Execute the function callable
        return_value = self.__fn__(*args, **kwargs)
        # If a return type was passed, validate the returned object
        if self.__return_type__ is not None:
            self.__check_type__(return_value, self.__return_type__, "return")
        return return_value

    def __quick_check__(self, subtype, obj):
        """
        Fast path for flat item types: returns True when `subtype` has no nested
        typing and the exact type of every item in `obj` is one of its keys.
        Returns False when a full item by item validation is still required.
        """
        if all(nested is None for nested in subtype.values()):
            allowed_types = set(subtype.keys())
            found_types = {type(item) for item in obj}
            if found_types.issubset(allowed_types):
                # We can return True to bypass the full validation
                return True
        # Otherwise a full validation is required to raise an indexed/keyed type mismatch error
        return False

    def __check_type__(self, obj, expected, key):
        """
        Raises an exception (or prints a warning when not strict) if the passed
        `obj` does not match the checkable type structure `expected`.

        Requires:

        - `obj`: The value to validate
        - `expected`: A checkable type dict as built by `__get_checkable_type__`
        - `key`: The label used for this value in error messages
        """
        # Special case for None
        if obj is None and type(None) in expected:
            return
        extra = expected.get("__extra__", {})
        expected = {k: v for k, v in expected.items() if k != "__extra__"}
        # Parameterized aliases such as Type[X] cannot be used with isinstance
        # (it raises), so only plain classes are used for instance checks.
        instance_types = tuple(
            k for k in expected if not hasattr(k, "__origin__")
        )

        if isinstance(obj, type):
            # An uninitialized class is passed: match it as Type[obj], or as an
            # instance of an expected plain class (e.g. `type` or `object`).
            obj_type = Type[obj]
            is_present = obj_type in expected or isinstance(obj, instance_types)
        else:
            obj_type = type(obj)
            is_present = isinstance(obj, instance_types)

        if not is_present:
            # Allow for literals to be used to bypass type checks if present
            literal = extra.get("__literal__", ())
            if literal:
                if obj not in literal:
                    self.__exception__(
                        f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead."
                    )
            # Raise an exception if the type is not in the expected types
            else:
                self.__exception__(
                    f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead."
                )
        # If the object_type is in the expected types, we can proceed with validation
        elif obj_type in iterable_types:
            subtype = expected.get(obj_type, None)
            if subtype is None:
                pass
            # Recursive validation
            elif obj_type == list:
                if not self.__quick_check__(subtype, obj):
                    for idx, item in enumerate(obj):
                        self.__check_type__(item, subtype, f"{key}[{idx}]")
            elif obj_type == dict:
                key_type, val_type = subtype
                # Distinct loop names: `key` must keep naming the variable
                # being validated for the labels and the constraint checks.
                if not self.__quick_check__(key_type, obj.keys()):
                    for dict_key in obj:
                        self.__check_type__(
                            dict_key,
                            key_type,
                            f"{key}.key[{repr(dict_key)}]",
                        )
                if not self.__quick_check__(val_type, obj.values()):
                    for dict_key, dict_value in obj.items():
                        self.__check_type__(
                            dict_value,
                            val_type,
                            f"{key}[{repr(dict_key)}]",
                        )
            elif obj_type == tuple:
                expected_args, is_ellipsis = subtype
                if is_ellipsis:
                    if not self.__quick_check__(expected_args, obj):
                        for idx, item in enumerate(obj):
                            self.__check_type__(
                                item, expected_args, f"{key}[{idx}]"
                            )
                else:
                    if len(obj) != len(expected_args):
                        self.__exception__(
                            f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}"
                        )
                    for idx, (item, ex) in enumerate(zip(obj, expected_args)):
                        self.__check_type__(item, ex, f"{key}[{idx}]")
            elif obj_type == set:
                if not self.__quick_check__(subtype, obj):
                    for item in obj:
                        self.__check_type__(
                            item, subtype, f"{key}[{repr(item)}]"
                        )

        # Validate constraints if any are present
        for constraint in extra.get("__constraints__", []):
            constraint_validation_output = constraint.__validate__(key, obj)
            if constraint_validation_output is not True:
                self.__exception__(
                    f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}"
                )

    def __repr__(self):
        return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"


@Partial
def Enforcer(clsFnMethod, enabled=True, strict=True):
    """
    A wrapper to enforce types within a function or method given argument annotations.

    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible.

    If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:

    - Methods with `__call__`
    - Methods wrapped with `staticmethod` (if python >= 3.10)
    - Methods wrapped with `classmethod` (if python >= 3.10)

    Requires:

    - `clsFnMethod`:
        - What: The class, function or method that should have input types enforced
        - Type: function | method | class

    Optional:

    - `enabled`:
        - What: A boolean to enable or disable the enforcer
        - Type: bool
        - Default: True
    - `strict`:
        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
        - Type: bool
        - Default: True
        - Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.


    Example Use:
    ```
    >>> import type_enforced
    >>> @type_enforced.Enforcer
    ... def my_fn(a: int, b: int | str = 2, c: int = 3) -> None:
    ...     pass
    ...
    >>> my_fn(a=1, b=2, c=3)
    >>> my_fn(a=1, b='2', c=3)
    >>> my_fn(a='a', b=2, c=3)
    Traceback (most recent call last):
      ...
    TypeError: TypeEnforced Exception (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.
    ```
    """
    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
        # Objects that reject new attributes (e.g. builtins) cannot be flagged
        # and are returned untouched. `Exception` keeps this best-effort while
        # no longer swallowing KeyboardInterrupt / SystemExit.
        try:
            clsFnMethod.__type_enforced_enabled__ = enabled
        except Exception:
            return clsFnMethod
    if not clsFnMethod.__type_enforced_enabled__:
        return clsFnMethod
    if isinstance(
        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
    ):
        # Only apply the enforcer if type_hints are present
        if get_type_hints(clsFnMethod) == {}:
            return clsFnMethod
        # Re-wrap static/class methods so the descriptor type is preserved
        if isinstance(clsFnMethod, staticmethod):
            return staticmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__, __strict__=strict
                )
            )
        if isinstance(clsFnMethod, classmethod):
            return classmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__, __strict__=strict
                )
            )
        return FunctionMethodEnforcer(__fn__=clsFnMethod, __strict__=strict)
    if hasattr(clsFnMethod, "__dict__"):
        # Iterate over a snapshot because attributes are reassigned in the loop
        for key, value in tuple(clsFnMethod.__dict__.items()):
            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
            if key == "__annotate__" or isinstance(
                value, FunctionMethodEnforcer
            ):
                continue
            if hasattr(value, "__call__") or isinstance(
                value, (classmethod, staticmethod)
            ):
                setattr(
                    clsFnMethod,
                    key,
                    Enforcer(value, enabled=enabled, strict=strict),
                )
        return clsFnMethod
    raise Exception(
        "Enforcer can only be used on classes, methods, or functions."
    )
class
FunctionMethodEnforcer:
20class FunctionMethodEnforcer: 21 def __init__(self, __fn__, __strict__=False): 22 """ 23 Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`. 24 25 Requires: 26 27 - `__fn__`: 28 - What: The function to enforce 29 - Type: function | method | class 30 """ 31 update_wrapper(self, __fn__) 32 self.__fn__ = __fn__ 33 self.__strict__ = __strict__ 34 self.__outer_self__ = None 35 # Validate that the passed function or method is a method or function 36 self.__check_method_function__() 37 # Get input defaults for the function or method 38 self.__get_defaults__() 39 40 def __get_defaults__(self): 41 """ 42 Get the default values of the passed function or method and store them in `self.__fn_defaults__`. 43 """ 44 self.__fn_defaults__ = {} 45 if self.__fn__.__defaults__ is not None: 46 # Get the names of all provided default values for args 47 default_varnames = list(self.__fn__.__code__.co_varnames)[ 48 : self.__fn__.__code__.co_argcount 49 ][-len(self.__fn__.__defaults__) :] 50 # Update the output dictionary with the default values 51 self.__fn_defaults__.update( 52 dict(zip(default_varnames, self.__fn__.__defaults__)) 53 ) 54 if self.__fn__.__kwdefaults__ is not None: 55 # Update the output dictionary with the keyword default values 56 self.__fn_defaults__.update(self.__fn__.__kwdefaults__) 57 58 def __get_checkable_types__(self): 59 """ 60 Creates two class attributes: 61 62 - `self.__checkable_types__`: 63 - What: A dictionary of all annotations as checkable types 64 - Type: dict 65 66 - `self.__return_type__`: 67 - What: The return type of the function or method 68 - Type: dict | None 69 """ 70 if not hasattr(self, "__checkable_types__"): 71 self.__checkable_types__ = { 72 key: self.__get_checkable_type__(value) 73 for key, value in get_type_hints(self.__fn__).items() 74 } 75 self.__return_type__ = self.__checkable_types__.pop("return", None) 76 77 def __get_checkable_type__(self, annotation): 78 """ 79 Parses a type annotation 
and returns a nested dict structure 80 representing the checkable type(s) for validation. 81 """ 82 83 if annotation is None: 84 return {type(None): None} 85 86 # Handle `int | str` syntax (Python 3.10+) and Unions 87 if ( 88 isinstance(annotation, UnionType) 89 or getattr(annotation, "__origin__", None) == Union 90 ): 91 combined_types = {} 92 for sub_type in annotation.__args__: 93 combined_types = DeepMerge( 94 combined_types, self.__get_checkable_type__(sub_type) 95 ) 96 return combined_types 97 98 # Handle typing.Literal 99 if getattr(annotation, "__origin__", None) == Literal: 100 return {"__extra__": {"__literal__": list(annotation.__args__)}} 101 102 # Handle generic collections 103 origin = getattr(annotation, "__origin__", None) 104 args = getattr(annotation, "__args__", ()) 105 106 if origin == list: 107 if len(args) != 1: 108 self.__exception__( 109 f"List must have a single type argument, got: {args}", 110 raise_exception=True, 111 ) 112 return {list: self.__get_checkable_type__(args[0])} 113 114 if origin == dict: 115 if len(args) != 2: 116 self.__exception__( 117 f"Dict must have two type arguments, got: {args}", 118 raise_exception=True, 119 ) 120 key_type = self.__get_checkable_type__(args[0]) 121 value_type = self.__get_checkable_type__(args[1]) 122 return {dict: (key_type, value_type)} 123 124 if origin == tuple: 125 if len(args) > 2 or len(args) == 1: 126 if Ellipsis in args: 127 self.__exception__( 128 "Tuple with Ellipsis must have exactly two type arguments and the second must be Ellipsis.", 129 raise_exception=True, 130 ) 131 if len(args) == 2: 132 if args[0] is Ellipsis: 133 self.__exception__( 134 "Tuple with Ellipsis must have exactly two type arguments and the first must not be Ellipsis.", 135 raise_exception=True, 136 ) 137 if args[1] is Ellipsis: 138 return {tuple: (self.__get_checkable_type__(args[0]), True)} 139 return { 140 tuple: ( 141 tuple(self.__get_checkable_type__(arg) for arg in args), 142 False, 143 ) 144 } 145 146 if origin 
== set: 147 if len(args) != 1: 148 self.__exception__( 149 f"Set must have a single type argument, got: {args}", 150 raise_exception=True, 151 ) 152 return {set: self.__get_checkable_type__(args[0])} 153 154 # Handle Sized types 155 if annotation == Sized: 156 return { 157 list: None, 158 tuple: None, 159 dict: None, 160 set: None, 161 str: None, 162 bytes: None, 163 bytearray: None, 164 memoryview: None, 165 range: None, 166 } 167 168 # Handle Callable types 169 if annotation == Callable: 170 return { 171 staticmethod: None, 172 classmethod: None, 173 FunctionType: None, 174 BuiltinFunctionType: None, 175 MethodType: None, 176 BuiltinMethodType: None, 177 GeneratorType: None, 178 } 179 180 if annotation == Any: 181 return { 182 object: None, 183 } 184 185 # Handle Constraints 186 if isinstance(annotation, GenericConstraint): 187 return {"__extra__": {"__constraints__": [annotation]}} 188 189 # Handle standard types 190 if isinstance(annotation, type): 191 return {annotation: None} 192 193 # Hanldle typing.Type (for uninitialized classes) 194 if origin is type and len(args) == 1: 195 return {annotation: None} 196 197 self.__exception__( 198 f"Unsupported type hint: {annotation}", raise_exception=True 199 ) 200 201 def __exception__(self, message, raise_exception=False): 202 """ 203 Usage: 204 205 - Creates a class based exception message 206 207 Requires: 208 209 - `message`: 210 - Type: str 211 - What: The message to warn users with 212 213 Optional: 214 215 - `raise_exception`: 216 - Type: bool 217 - What: Forces an exception to be raised regardless of the `self.__strict__` setting. 
218 - Default: False 219 """ 220 if self.__strict__ or raise_exception: 221 raise TypeError( 222 f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}" 223 ) 224 else: 225 print( 226 f"TypeEnforced Warning ({self.__fn__.__qualname__}): {message}" 227 ) 228 229 def __get__(self, obj, objtype): 230 """ 231 Overwrite standard __get__ method to return __call__ instead for wrapped class methods. 232 233 Also stores the calling (__get__) `obj` to be passed as an initial argument for `__call__` such that methods can pass `self` correctly. 234 """ 235 236 @wraps(self.__fn__) 237 def __get_fn__(*args, **kwargs): 238 return self.__call__(*args, **kwargs) 239 240 self.__outer_self__ = obj 241 return __get_fn__ 242 243 def __check_method_function__(self): 244 """ 245 Validate that `self.__fn__` is a method or function 246 """ 247 if not isinstance(self.__fn__, (MethodType, FunctionType)): 248 raise Exception( 249 f"A non function/method was passed to Enforcer. See the stack trace above for more information." 250 ) 251 252 def __call__(self, *args, **kwargs): 253 """ 254 This method is used to validate the passed inputs and return the output of the wrapped function or method. 
255 """ 256 # Special code to pass self as an initial argument 257 # for validation purposes in methods 258 # See: self.__get__ 259 if self.__outer_self__ is not None: 260 args = (self.__outer_self__, *args) 261 # Get a dictionary of all annotations as checkable types 262 # Note: This is only done once at first call to avoid redundant calculations 263 self.__get_checkable_types__() 264 # Create a compreshensive dictionary of assigned variables (order matters) 265 assigned_vars = { 266 **self.__fn_defaults__, 267 **dict(zip(self.__fn__.__code__.co_varnames[: len(args)], args)), 268 **kwargs, 269 } 270 # Validate all listed annotations vs the assigned_vars dictionary 271 for key, value in self.__checkable_types__.items(): 272 self.__check_type__(assigned_vars.get(key), value, key) 273 # Execute the function callable 274 return_value = self.__fn__(*args, **kwargs) 275 # If a return type was passed, validate the returned object 276 if self.__return_type__ is not None: 277 self.__check_type__(return_value, self.__return_type__, "return") 278 return return_value 279 280 def __quick_check__(self, subtype, obj): 281 if all([v == None for v in subtype.values()]): 282 # If the subtype does not contain iterables with typing, we can validate the items directly. 283 types = set(subtype.keys()) 284 values = set([type(v) for v in obj]) 285 if values.issubset(types): 286 # We can return True to bypass the full validation 287 return True 288 # Otherwise, validation did not pass and a full validation is required to raise an indexed/keyed type mismatch error 289 return False 290 291 def __check_type__(self, obj, expected, key): 292 """ 293 Raises an exception the type of a passed `obj` (parameter) is not in the list of supplied `acceptable_types` for the argument. 
294 """ 295 # Special case for None 296 if obj is None and type(None) in expected: 297 return 298 extra = expected.get("__extra__", {}) 299 expected = {k: v for k, v in expected.items() if k != "__extra__"} 300 301 if isinstance(obj, type): 302 # An uninitialized class is passed, we need to check if the type is in the expected types using Type[obj] 303 obj_type = Type[obj] 304 is_present = obj_type in expected 305 else: 306 obj_type = type(obj) 307 is_present = isinstance(obj, tuple(expected.keys())) 308 309 if not is_present: 310 # Allow for literals to be used to bypass type checks if present 311 literal = extra.get("__literal__", ()) 312 if literal: 313 if obj not in literal: 314 self.__exception__( 315 f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` or a literal value in `{literal}` but got type `{obj_type}` with value `{obj}` instead." 316 ) 317 # Raise an exception if the type is not in the expected types 318 else: 319 self.__exception__( 320 f"Type mismatch for typed variable `{key}`. Expected one of the following `{list(expected.keys())}` but got `{obj_type}` with value `{obj}` instead." 321 ) 322 # If the object_type is in the expected types, we can proceed with validation 323 elif obj_type in iterable_types: 324 subtype = expected.get(obj_type, None) 325 if subtype is None: 326 pass 327 # Recursive validation 328 elif obj_type == list: 329 # If the subtype does not contain iterables with typing, we can validate the items directly. 
330 if not self.__quick_check__(subtype, obj): 331 for idx, item in enumerate(obj): 332 self.__check_type__(item, subtype, f"{key}[{idx}]") 333 elif obj_type == dict: 334 key_type, val_type = subtype 335 if not self.__quick_check__(key_type, obj.keys()): 336 for key in obj.keys(): 337 self.__check_type__( 338 key, key_type, f"{key}.key[{repr(key)}]" 339 ) 340 if not self.__quick_check__(val_type, obj.values()): 341 for key, value in obj.items(): 342 self.__check_type__( 343 value, val_type, f"{key}[{repr(key)}]" 344 ) 345 elif obj_type == tuple: 346 expected_args, is_ellipsis = subtype 347 if is_ellipsis: 348 if not self.__quick_check__(expected_args, obj): 349 for idx, item in enumerate(obj): 350 self.__check_type__( 351 item, expected_args, f"{key}[{idx}]" 352 ) 353 else: 354 if len(obj) != len(expected_args): 355 self.__exception__( 356 f"Tuple length mismatch for `{key}`. Expected length {len(expected_args)}, got {len(obj)}" 357 ) 358 for idx, (item, ex) in enumerate(zip(obj, expected_args)): 359 self.__check_type__(item, ex, f"{key}[{idx}]") 360 elif obj_type == set: 361 if not self.__quick_check__(subtype, obj): 362 for item in obj: 363 self.__check_type__( 364 item, subtype, f"{key}[{repr(item)}]" 365 ) 366 367 # Validate constraints if any are present 368 constraints = extra.get("__constraints__", []) 369 for constraint in constraints: 370 constraint_validation_output = constraint.__validate__(key, obj) 371 if constraint_validation_output is not True: 372 self.__exception__( 373 f"Constraint validation error for variable `{key}` with value `{obj}`. {constraint_validation_output}" 374 ) 375 376 def __repr__(self): 377 return f"<type_enforced {self.__fn__.__module__}.{self.__fn__.__qualname__} object at {hex(id(self))}>"
FunctionMethodEnforcer(__fn__, __strict__=False)
21 def __init__(self, __fn__, __strict__=False): 22 """ 23 Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function `__fn__`. 24 25 Requires: 26 27 - `__fn__`: 28 - What: The function to enforce 29 - Type: function | method | class 30 """ 31 update_wrapper(self, __fn__) 32 self.__fn__ = __fn__ 33 self.__strict__ = __strict__ 34 self.__outer_self__ = None 35 # Validate that the passed function or method is a method or function 36 self.__check_method_function__() 37 # Get input defaults for the function or method 38 self.__get_defaults__()
Initialize a FunctionMethodEnforcer class object as a wrapper for a passed function __fn__
.
Requires:
- `__fn__`:
- What: The function to enforce
- Type: function | method
@Partial
def
Enforcer(clsFnMethod, enabled=True, strict=True):
@Partial
def Enforcer(clsFnMethod, enabled=True, strict=True):
    """
    A wrapper to enforce types within a function or method given argument annotations.

    Each wrapped item is converted into a special `FunctionMethodEnforcer` class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a `FunctionMethodEnforcer` class as no validation is possible.

    If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:

    - Methods with `__call__`
    - Methods wrapped with `staticmethod` (if python >= 3.10)
    - Methods wrapped with `classmethod` (if python >= 3.10)

    Requires:

    - `clsFnMethod`:
        - What: The class, function or method that should have input types enforced
        - Type: function | method | class

    Optional:

    - `enabled`:
        - What: A boolean to enable or disable the enforcer
        - Type: bool
        - Default: True
    - `strict`:
        - What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
        - Type: bool
        - Default: True
        - Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.


    Example Use:
    ```
    >>> import type_enforced
    >>> @type_enforced.Enforcer
    ... def my_fn(a: int, b: int | str = 2, c: int = 3) -> None:
    ...     pass
    ...
    >>> my_fn(a=1, b=2, c=3)
    >>> my_fn(a=1, b='2', c=3)
    >>> my_fn(a='a', b=2, c=3)
    Traceback (most recent call last):
      ...
    TypeError: TypeEnforced Exception (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.
    ```
    """
    if not hasattr(clsFnMethod, "__type_enforced_enabled__"):
        # Objects that reject new attributes (e.g. builtins) cannot be flagged
        # and are returned untouched. `Exception` keeps this best-effort while
        # no longer swallowing KeyboardInterrupt / SystemExit.
        try:
            clsFnMethod.__type_enforced_enabled__ = enabled
        except Exception:
            return clsFnMethod
    if not clsFnMethod.__type_enforced_enabled__:
        return clsFnMethod
    if isinstance(
        clsFnMethod, (staticmethod, classmethod, FunctionType, MethodType)
    ):
        # Only apply the enforcer if type_hints are present
        if get_type_hints(clsFnMethod) == {}:
            return clsFnMethod
        # Re-wrap static/class methods so the descriptor type is preserved
        if isinstance(clsFnMethod, staticmethod):
            return staticmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__, __strict__=strict
                )
            )
        if isinstance(clsFnMethod, classmethod):
            return classmethod(
                FunctionMethodEnforcer(
                    __fn__=clsFnMethod.__func__, __strict__=strict
                )
            )
        return FunctionMethodEnforcer(__fn__=clsFnMethod, __strict__=strict)
    if hasattr(clsFnMethod, "__dict__"):
        # Iterate over a snapshot because attributes are reassigned in the loop
        for key, value in tuple(clsFnMethod.__dict__.items()):
            # Skip the __annotate__ method if present in __dict__ as it deletes itself upon invocation
            # Skip any previously wrapped methods if they are already a FunctionMethodEnforcer
            if key == "__annotate__" or isinstance(
                value, FunctionMethodEnforcer
            ):
                continue
            if hasattr(value, "__call__") or isinstance(
                value, (classmethod, staticmethod)
            ):
                setattr(
                    clsFnMethod,
                    key,
                    Enforcer(value, enabled=enabled, strict=strict),
                )
        return clsFnMethod
    raise Exception(
        "Enforcer can only be used on classes, methods, or functions."
    )
A wrapper to enforce types within a function or method given argument annotations.
Each wrapped item is converted into a special FunctionMethodEnforcer
class object that validates the passed parameters for the function or method when it is called. If a function or method that is passed does not have any annotations, it is not converted into a FunctionMethodEnforcer
class as no validation is possible.
If wrapping a class, all methods in the class that meet any of the following criteria will be wrapped individually:
- Methods with
__call__
- Methods wrapped with
staticmethod
(if python >= 3.10) - Methods wrapped with
classmethod
(if python >= 3.10)
Requires:
clsFnMethod
:- What: The class, function or method that should have input types enforced
- Type: function | method | class
Optional:
enabled
:- What: A boolean to enable or disable the enforcer
- Type: bool
- Default: True
strict
:- What: A boolean to enable or disable exceptions. If True, exceptions will be raised when type checking fails. If False, exceptions will not be raised but instead a warning will be printed to the console.
- Type: bool
- Default: True
- Note: Type hints that are wrapped with the type enforcer and are invalid will still raise an exception.
Example Use:
>>> import type_enforced
>>> @type_enforced.Enforcer
... def my_fn(a: int, b: int | str = 2, c: int = 3) -> None:
... pass
...
>>> my_fn(a=1, b=2, c=3)
>>> my_fn(a=1, b='2', c=3)
>>> my_fn(a='a', b=2, c=3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 85, in __call__
self.__check_type__(assigned_vars.get(key), value, key)
File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 107, in __check_type__
self.__exception__(
File "/home/conmak/development/personal/type_enforced/type_enforced/enforcer.py", line 34, in __exception__
raise TypeError(f"TypeEnforced Exception ({self.__fn__.__qualname__}): {message}")
TypeError: TypeEnforced Exception (my_fn): Type mismatch for typed variable `a`. Expected one of the following `[<class 'int'>]` but got `<class 'str'>` with value `a` instead.