Coverage for src/appl/core/patch.py: 89%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

import contextvars
import functools
import threading

3 

# Name of the marker attribute that patch_threading() sets on the replacement
# Thread class; its presence on threading.Thread means the patch is applied.
APPL_PATCHED_NAME: str = "__APPL_PATCHED__"

5 

6 

7def _get_new_target(target): 

8 # copy the current context and propagate it to the function in the new thread 

9 ctx = contextvars.copy_context() 

10 

11 def new_target(*args, **kwargs): 

12 return ctx.run(target, *args, **kwargs) 

13 

14 return new_target 

15 

16 

def patch_threading() -> None:
    """Patch ``threading.Thread`` so thread targets inherit the caller's context.

    Replaces ``threading.Thread`` with a subclass whose constructor wraps a
    non-``None`` ``target`` (given by keyword or positionally) with
    ``_get_new_target``. The target then runs inside a copy of the
    ``contextvars`` context captured when the thread object is constructed.

    The replacement class is tagged with the ``APPL_PATCHED_NAME`` attribute,
    and the function returns early when that tag is already present, so
    calling it more than once is a no-op.

    Only code that looks up ``threading.Thread`` after the patch is affected.
    Names bound earlier (e.g. ``from threading import Thread``) and classes
    that subclassed the original ``Thread`` keep the unpatched behavior.
    """
    if hasattr(threading.Thread, APPL_PATCHED_NAME):
        return

    class ThreadWithContext(threading.Thread):
        """``threading.Thread`` whose target runs in the constructor's context."""

        def __init__(self, *args, **kwargs):
            # Thread's signature is (group=None, target=None, name=None,
            # args=(), kwargs=None, *, daemon=None). A positionally passed
            # target therefore sits at index 1, after ``group``, not index 0.
            if kwargs.get("target") is not None:
                kwargs["target"] = _get_new_target(kwargs["target"])
            elif len(args) > 1 and args[1] is not None:
                args = (args[0], _get_new_target(args[1]), *args[2:])
            super().__init__(*args, **kwargs)

    setattr(ThreadWithContext, APPL_PATCHED_NAME, True)
    threading.Thread = ThreadWithContext  # type: ignore