Coverage for src/appl/core/globals.py: 86%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1import contextvars 

2import threading 

3from argparse import Namespace 

4from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor 

5from enum import Enum 

6from typing import Any, Union 

7 

# Process-wide shared state. A single Namespace instance is used as the
# container; `lock` is the lock taken by get_global_var / set_global_var /
# inc_global_var when they read or write attributes on it.
global_vars = Namespace()
global_vars.lock = threading.Lock()

# tracing
global_vars.trace_engine = None
global_vars.gen_cnt = 0
# ContextVar (rather than a plain attribute) so the value is tracked per
# execution context; it reads as None until something sets it.
global_vars.current_func = contextvars.ContextVar("current_func", default=None)

# thread-level vars: attributes stored here are visible only to the thread
# that set them (see get_thread_local / set_thread_local).
thread_local = threading.local()

# streaming
global_vars.live = None
# Separate lock from `global_vars.lock`, dedicated to the `live` object.
global_vars.live_lock = threading.Lock()

# executors (to be replaced by appl.init())
global_vars.llm_thread_executor = ThreadPoolExecutor(
    max_workers=10, thread_name_prefix="llm"
)
global_vars.general_thread_executor = ThreadPoolExecutor(
    max_workers=20, thread_name_prefix="general"
)
global_vars.general_process_executor = ProcessPoolExecutor(max_workers=10)

32 

33 

def get_thread_local(name: str, default: Any = None) -> Any:
    """Get the value of a thread-local variable.

    Args:
        name: Attribute name to look up on the module's ``thread_local`` object.
        default: Value returned when the calling thread has not set ``name``.

    Returns:
        The value stored under ``name`` for the calling thread, or ``default``.
    """
    return getattr(thread_local, name, default)

37 

38 

def set_thread_local(name: str, value: Any) -> None:
    """Set the value of a thread-local variable.

    Args:
        name: Attribute name to assign on the module's ``thread_local`` object.
        value: Value to store for the calling thread.
    """
    setattr(thread_local, name, value)

42 

43 

def inc_thread_local(name: str, delta: Union[int, float] = 1) -> Any:
    """Increment a thread-local variable by a delta and return the new value.

    Args:
        name: Name of the thread-local variable to update.
        delta: Amount to add; defaults to 1.

    Returns:
        The updated value that was stored back under ``name``.
    """
    # A variable not yet set on this thread is treated as 0.
    value = get_thread_local(name, 0)
    value += delta
    # Write back through the setter so reads and writes use the same helpers.
    set_thread_local(name, value)
    return value

50 

51 

def get_global_var(name: str, default: Any = None) -> Any:
    """Get the value of a global variable.

    The read is performed while holding ``global_vars.lock``.

    Args:
        name: Attribute name to look up on ``global_vars``.
        default: Value returned when ``name`` is not set on ``global_vars``.

    Returns:
        The value stored under ``name``, or ``default``.
    """
    with global_vars.lock:
        return getattr(global_vars, name, default)

56 

57 

def set_global_var(name: str, value: Any) -> None:
    """Set the value of a global variable.

    The write is performed while holding ``global_vars.lock``.

    Args:
        name: Attribute name to assign on ``global_vars``.
        value: Value to store.
    """
    with global_vars.lock:
        setattr(global_vars, name, value)

62 

63 

def inc_global_var(name: str, delta: Union[int, float] = 1) -> Any:
    """Increment a global variable by a delta and return the new value.

    The read, add and write-back all happen under a single acquisition of
    ``global_vars.lock``.

    Args:
        name: Attribute name on ``global_vars`` to update.
        delta: Amount to add; defaults to 1.

    Returns:
        The updated value that was stored back under ``name``.
    """
    with global_vars.lock:
        # A variable that is not set yet is treated as 0.
        value = getattr(global_vars, name, 0)
        value += delta
        setattr(global_vars, name, value)
        return value

71 

72 

class ExecutorType(str, Enum):
    """The type of the executor.

    Members are also ``str`` instances, so each one compares equal to its
    string value (e.g. ``ExecutorType.NEW_THREAD == "new_thread"``).
    """

    LLM_THREAD_POOL = "llm_thread_pool"
    GENERAL_THREAD_POOL = "general_thread_pool"
    GENERAL_PROCESS_POOL = "general_process_pool"
    NEW_THREAD = "new_thread"
    NEW_PROCESS = "new_process"

81 

82 

def get_executor(
    executor_type: ExecutorType,
) -> Union[ThreadPoolExecutor, ProcessPoolExecutor]:
    """Get the executor of a given type.

    The three pool types return the shared executors stored on
    ``global_vars``. ``NEW_THREAD`` and ``NEW_PROCESS`` construct a fresh
    single-worker executor on every call; this function keeps no reference
    to it.

    Args:
        executor_type: Which executor to return.

    Returns:
        The matching thread or process pool executor.

    Raises:
        ValueError: If ``executor_type`` matches no ``ExecutorType`` member.
    """
    # Compared with `==` (not dict lookup or `is`) so that a plain string
    # equal to a member's value is matched as well.
    if executor_type == ExecutorType.LLM_THREAD_POOL:
        return global_vars.llm_thread_executor
    elif executor_type == ExecutorType.GENERAL_THREAD_POOL:
        return global_vars.general_thread_executor
    elif executor_type == ExecutorType.GENERAL_PROCESS_POOL:
        return global_vars.general_process_executor
    elif executor_type == ExecutorType.NEW_THREAD:
        return ThreadPoolExecutor(max_workers=1)
    elif executor_type == ExecutorType.NEW_PROCESS:
        return ProcessPoolExecutor(max_workers=1)
    else:
        raise ValueError(f"Invalid executor type: {executor_type}")