Coverage for src/appl/core/patch.py: 89%
19 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
import contextvars
import functools
import threading
# Name of the marker attribute set on the patched Thread class; patch_threading()
# checks for it so that applying the patch a second time is a no-op.
APPL_PATCHED_NAME: str = "__APPL_PATCHED__"
def _get_new_target(target):
    """Wrap ``target`` so it runs inside a snapshot of the caller's context.

    The ``contextvars`` context is copied when this function is called (for a
    patched thread, that is when the Thread object is constructed), so values
    set by the creating thread are visible to ``target`` when it later runs.

    Args:
        target: The callable to wrap.

    Returns:
        A callable accepting the same arguments as ``target``. It invokes
        ``target`` via ``Context.run`` on the copied context and returns its
        result. The wrapper carries ``target``'s metadata (``__name__``,
        ``__qualname__``, ``__doc__``, ...) and exposes it as ``__wrapped__``.
    """
    # copy the current context and propagate it to the function in the new thread
    ctx = contextvars.copy_context()

    # functools.wraps preserves target's identity: threading.Thread derives its
    # default name ("Thread-N (<target name>)") from target.__name__, and
    # tracebacks/debuggers should show the real function, not "new_target".
    @functools.wraps(target)
    def new_target(*args, **kwargs):
        return ctx.run(target, *args, **kwargs)

    return new_target
def patch_threading() -> None:
    """Patch threading.Thread to automatically wrap the target with context.

    Replaces the ``threading.Thread`` module attribute with a subclass whose
    constructor wraps ``target`` with ``_get_new_target``. The target is
    wrapped whether it is passed positionally or by keyword. As a result it
    runs in a copy of the ``contextvars`` context that was current when the
    Thread object was constructed.

    The patch is applied at most once. The subclass is tagged with the
    ``APPL_PATCHED_NAME`` attribute, and later calls see the tag and return
    immediately.

    NOTE(review): only lookups of ``threading.Thread`` made after this call
    get the subclass. Names bound earlier (e.g. ``from threading import
    Thread``) and Thread subclasses defined earlier still use the original
    class.
    """
    if hasattr(threading.Thread, APPL_PATCHED_NAME):
        return  # already patched

    class ThreadWithContext(threading.Thread):
        def __init__(self, group=None, target=None, *args, **kwargs):
            # threading.Thread's signature is
            # (group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None),
            # so a positional target is the SECOND argument, not the first.
            # Naming group/target explicitly captures target whether it was
            # passed positionally or by keyword; everything else passes through.
            if target is not None:
                target = _get_new_target(target)
            super().__init__(group, target, *args, **kwargs)

    setattr(ThreadWithContext, APPL_PATCHED_NAME, True)
    threading.Thread = ThreadWithContext  # type: ignore