Coverage for src/appl/core/globals.py: 86%
56 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1import contextvars
2import threading
3from argparse import Namespace
4from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
5from enum import Enum
6from typing import Any, Union
# Process-wide shared state, kept as attributes on one argparse Namespace.
# The get_/set_/inc_global_var helpers in this module access it while
# holding ``global_vars.lock``.
global_vars = Namespace(
    # guards the *_global_var helpers
    lock=threading.Lock(),
    # tracing
    trace_engine=None,
    gen_cnt=0,
    current_func=contextvars.ContextVar("current_func", default=None),
    # streaming
    live=None,
    live_lock=threading.Lock(),
    # executors (to be replaced by appl.init())
    llm_thread_executor=ThreadPoolExecutor(
        max_workers=10, thread_name_prefix="llm"
    ),
    general_thread_executor=ThreadPoolExecutor(
        max_workers=20, thread_name_prefix="general"
    ),
    general_process_executor=ProcessPoolExecutor(max_workers=10),
)

# Per-thread storage used by the *_thread_local helpers; each thread sees
# its own independent set of attributes.
thread_local = threading.local()
def get_thread_local(name: str, default: Any = None) -> Any:
    """Return the calling thread's value for ``name``.

    Falls back to ``default`` when this thread has never set ``name``.
    """
    try:
        return getattr(thread_local, name)
    except AttributeError:
        # Attribute not bound in this thread's storage.
        return default
def set_thread_local(name: str, value: Any) -> None:
    """Bind ``value`` to ``name`` in the calling thread's local storage.

    Only the current thread observes the new binding.
    """
    # ``thread_local`` is a threading.local, so this write is per-thread.
    setattr(thread_local, name, value)
def inc_thread_local(name: str, delta: Union[int, float] = 1) -> Any:
    """Add ``delta`` to the calling thread's ``name`` and return the result.

    An unset ``name`` starts from 0. The updated value is stored back into
    this thread's local storage before it is returned.
    """
    current = get_thread_local(name, 0)
    current += delta
    set_thread_local(name, current)
    return current
def get_global_var(name: str, default: Any = None) -> Any:
    """Read ``global_vars.<name>`` while holding ``global_vars.lock``.

    Returns ``default`` if the attribute does not exist.
    """
    guard = global_vars.lock
    with guard:
        found = getattr(global_vars, name, default)
    return found
def set_global_var(name: str, value: Any) -> None:
    """Assign ``global_vars.<name> = value`` while holding ``global_vars.lock``."""
    guard = global_vars.lock
    guard.acquire()
    try:
        setattr(global_vars, name, value)
    finally:
        # Always release, even if the assignment raises.
        guard.release()
def inc_global_var(name: str, delta: Union[int, float] = 1) -> Any:
    """Add ``delta`` to ``global_vars.<name>`` and return the new value.

    A missing attribute counts as 0. The read, add and write-back all happen
    under ``global_vars.lock``. As a result, callers that go through the
    lock-holding helpers in this module do not interleave with the update.
    """
    with global_vars.lock:
        current = getattr(global_vars, name, 0)
        current += delta
        setattr(global_vars, name, current)
    return current
class ExecutorType(str, Enum):
    """Selector passed to ``get_executor`` to choose an executor.

    Each member is also a ``str``, so it compares equal to its string value
    (for example ``ExecutorType.NEW_THREAD == "new_thread"``).
    """

    # Shared pools kept on ``global_vars``.
    LLM_THREAD_POOL = "llm_thread_pool"
    GENERAL_THREAD_POOL = "general_thread_pool"
    GENERAL_PROCESS_POOL = "general_process_pool"
    # Freshly created single-worker executors.
    NEW_THREAD = "new_thread"
    NEW_PROCESS = "new_process"
def get_executor(
    executor_type: ExecutorType,
) -> Union[ThreadPoolExecutor, ProcessPoolExecutor]:
    """Return the executor selected by ``executor_type``.

    The three ``*_POOL`` types return the shared executor currently stored
    on ``global_vars``, looked up at call time. ``NEW_THREAD`` and
    ``NEW_PROCESS`` construct a new single-worker executor on every call.
    This function never shuts that executor down.

    Raises:
        ValueError: if ``executor_type`` equals none of the known types.
    """
    # Ordered (type, factory) pairs. The factories defer the ``global_vars``
    # lookup so that executors swapped in later are picked up.
    choices = (
        (ExecutorType.LLM_THREAD_POOL, lambda: global_vars.llm_thread_executor),
        (
            ExecutorType.GENERAL_THREAD_POOL,
            lambda: global_vars.general_thread_executor,
        ),
        (
            ExecutorType.GENERAL_PROCESS_POOL,
            lambda: global_vars.general_process_executor,
        ),
        (ExecutorType.NEW_THREAD, lambda: ThreadPoolExecutor(max_workers=1)),
        (ExecutorType.NEW_PROCESS, lambda: ProcessPoolExecutor(max_workers=1)),
    )
    for kind, make in choices:
        if executor_type == kind:
            return make()
    raise ValueError(f"Invalid executor type: {executor_type}")