Coverage for src/appl/__init__.py: 73%

145 statements  

coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1"""appl - A Prompt Programming Language.""" 

2 

3from __future__ import annotations 

4 

5import datetime 

6import inspect 

7import os 

8import sys 

9import threading 

10from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor 

11from contextlib import contextmanager 

12 

13import pendulum 

14import toml 

15import yaml 

16from dotenv import load_dotenv 

17from loguru import logger 

18 

19logger.remove() # Remove default handler 

20logger.add(sys.stderr, level="INFO") # set to INFO 

21 

22from typing import Any, Callable, Dict, Optional 

23 

24from .compositor import ApplStr as Str 

25from .compositor import iter 

26from .core import ( 

27 BracketedDefinition, 

28 CallFuture, 

29 CompletionResponse, 

30 Definition, 

31 Generation, 

32 Image, 

33 Indexing, 

34 Promptable, 

35 PromptContext, 

36 PromptPrinter, 

37 PromptRecords, 

38 StringFuture, 

39 Tool, 

40) 

41from .core import appl_compile as compile 

42from .core import appl_execute as execute 

43from .core import appl_format as format 

44from .core import appl_with_ctx as with_ctx 

45from .core.config import Configs, configs, load_config 

46from .core.generation import get_gen_name_prefix, set_gen_name_prefix 

47from .core.globals import global_vars 

48from .core.io import dump_file, load_file 

49from .core.message import ( 

50 AIMessage, 

51 Conversation, 

52 SystemMessage, 

53 ToolMessage, 

54 UserMessage, 

55) 

56from .core.patch import patch_threading 

57from .core.promptable import define, define_bracketed, promptify 

58from .core.trace import traceable 

59from .core.utils import need_ctx, partial, wraps 

60from .func import ( 

61 as_func, 

62 as_tool, 

63 as_tool_choice, 

64 call, 

65 convo, 

66 empty_line, 

67 gen, 

68 grow, 

69 openai_tool_schema, 

70 ppl, 

71 records, 

72 reset_context, 

73 str_future, 

74) 

75from .role_changer import AIRole, SystemRole, ToolRole, UserRole 

76from .servers import server_manager 

77from .tracing import TraceEngine 

78from .utils import ( 

79 LoguruFormatter, 

80 find_dotenv, 

81 find_files, 

82 get_folder, 

83 get_meta_file, 

84 timeit, 

85) 

86from .version import __version__ 

87 

88 

89def _get_loguru_format(): 

90 return LoguruFormatter( 

91 max_length=configs.getattrs("settings.logging.max_length"), 

92 suffix_length=configs.getattrs("settings.logging.suffix_length"), 

93 ).loguru_format 

94 

95 

96logger.remove() # Remove default handler 

97# update default handler for the loguru logger 

98logger.add(sys.stderr, level="INFO", format=_get_loguru_format()) # default 

99global_vars.initialized = False 

100 

101 

102def init( 

103 resume_cache: Optional[str] = None, 

104 update_config_hook: Optional[Callable] = None, 

105) -> None: 

106 """Initialize APPL with dotenv and config files. 

107 

108 Args: 

109 resume_cache: Path to the trace file used as resume cache. Defaults to None. 

110 update_config_hook: A hook to update the configs. Defaults to None. 

111 

112 Examples: 

113 ```python 

114 import appl 

115 

116 # Load environment variables from `.env` and configs from `appl.yaml`. 

117 # Initialize logging and tracing systems if enabled. 

118 appl.init() 

119 ``` 

120 """ 

121 with global_vars.lock: 

122 # only initialize once 

123 if global_vars.initialized: 

124 logger.warning("APPL has already been initialized, ignore") 

125 return 

126 global_vars.initialized = True 

127 

128 now = pendulum.instance(datetime.datetime.now()) 

129 # Get the previous frame in the stack, i.e., the one calling this function 

130 frame = inspect.currentframe() 

131 if frame and frame.f_back: 

132 caller_path = frame.f_back.f_code.co_filename # Get file_path of the caller 

133 caller_funcname = frame.f_back.f_code.co_name # Get function name of the caller 

134 caller_basename = os.path.basename(caller_path).split(".")[0] 

135 caller_folder = os.path.dirname(caller_path) # Get folder of the caller 

136 caller_folder = get_folder(caller_folder) 

137 dotenvs = find_files(caller_folder, [".env"]) 

138 appl_config_files = find_files( 

139 caller_folder, ["appl.yaml", "appl.yml", "appl.json", "appl.toml"] 

140 ) 

141 # load dotenvs and appl configs from outer to inner with override 

142 for dotenv in dotenvs[::-1]: 

143 load_dotenv(dotenv, override=True) 

144 logger.info("Loaded dotenv from {}".format(dotenv)) 

145 for config_file in appl_config_files[::-1]: 

146 override_configs = load_config(config_file) 

147 logger.info("Loaded configs from {}".format(config_file)) 

148 configs.update(override_configs) 

149 if configs.getattrs("settings.logging.display.configs_update"): 

150 logger.info(f"update configs:\n{yaml.dump(override_configs.to_dict())}") 

151 else: 

152 caller_basename, caller_funcname = "appl", "<module>" 

153 dotenvs, appl_config_files = [], [] 

154 logger.error( 

155 "Cannot find the caller of appl.init(), fail to load .env and appl configs" 

156 ) 

157 

158 if update_config_hook: 

159 update_config_hook(configs) 

160 

161 # ============================================================ 

162 # Logging 

163 # ============================================================ 

164 log_format = configs.getattrs("settings.logging.format") 

165 log_level = configs.getattrs("settings.logging.log_level") 

166 log_file = configs.getattrs("settings.logging.log_file") 

167 # set logger level for loguru 

168 logger.remove() # Remove default handler 

169 logger.add(sys.stderr, level=log_level, format=_get_loguru_format()) 

170 if log_file.get("enabled", False): 

171 if (log_file_format := log_file.get("path_format", None)) is not None: 

172 log_file_path = ( 

173 log_file_format.format( 

174 basename=caller_basename, funcname=caller_funcname, time=now 

175 ) 

176 + ".log" 

177 ) 

178 log_file.path = log_file_path # set the top level log file path 

179 file_log_level = log_file.get("log_level", None) or log_level 

180 logger.info(f"Logging to file: {log_file_path} with level {file_log_level}") 

181 # no need to overwrite the default format when writing to file 

182 logger.add(log_file_path, level=file_log_level, format=log_format) 

183 

184 configs["info"] = Configs( 

185 { 

186 "start_time": now.format("YYYY-MM-DD HH:mm:ss"), 

187 "dotenvs": dotenvs, 

188 "appl_configs": appl_config_files, 

189 } 

190 ) 

191 if configs.getattrs("settings.logging.display.configs"): 

192 logger.info(f"Using configs:\n{yaml.dump(configs.to_dict())}") 

193 

194 # ============================================================ 

195 # Concurrency 

196 # ============================================================ 

197 concurrency = configs.getattrs("settings.concurrency") 

198 llm_max_workers = concurrency.get("llm_max_workers", 10) 

199 thread_max_workers = concurrency.get("thread_max_workers", 20) 

200 process_max_workers = concurrency.get("process_max_workers", 10) 

201 global_vars.llm_thread_executor = ThreadPoolExecutor( 

202 max_workers=llm_max_workers, thread_name_prefix="llm" 

203 ) 

204 global_vars.thread_executor = ThreadPoolExecutor( 

205 max_workers=thread_max_workers, thread_name_prefix="general" 

206 ) 

207 global_vars.process_executor = ProcessPoolExecutor(max_workers=process_max_workers) 

208 

209 # ============================================================ 

210 # Tracing 

211 # ============================================================ 

212 tracing = configs.getattrs("settings.tracing") 

213 strict_match = tracing.get("strict_match", True) 

214 if tracing.get("enabled", False): 

215 if tracing.get("patch_threading", True): 

216 patch_threading() 

217 if (trace_file_format := tracing.get("path_format", None)) is not None: 

218 prefix = trace_file_format.format( 

219 basename=caller_basename, funcname=caller_funcname, time=now 

220 ) 

221 trace_file_path = f"{prefix}.pkl" 

222 meta_file = f"{prefix}_meta.json" 

223 tracing.trace_file = trace_file_path 

224 logger.info(f"Tracing file: {trace_file_path}") 

225 dump_file(configs.to_dict(), meta_file) 

226 global_vars.trace_engine = TraceEngine( 

227 trace_file_path, mode="write", strict=strict_match 

228 ) 

229 else: 

230 logger.warning("Tracing is enabled but no trace file is specified") 

231 

232 resume_cache = resume_cache or os.environ.get("APPL_RESUME_TRACE", None) 

233 if resume_cache: 

234 global_vars.resume_cache = resume_cache 

235 logger.info(f"Using resume cache: {resume_cache}") 

236 global_vars.resume_cache = TraceEngine( 

237 resume_cache, mode="read", strict=strict_match 

238 ) 

239 

240 
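# ----------------------------------------------------------------------
# For reference, a sketch of an `appl.yaml` covering the settings read by
# init() above. The keys mirror the `configs.getattrs(...)` lookups in this
# file; the values are illustrative assumptions, not the project defaults.
#
#   settings:
#     logging:
#       log_level: INFO
#       format: "{time} | {level} | {message}"   # assumed loguru-style format
#       max_length: 800                          # assumed truncation settings
#       suffix_length: 200
#       display:
#         configs: false
#         configs_update: true
#       log_file:
#         enabled: true
#         log_level: DEBUG
#         path_format: "./logs/{basename}_{funcname}_{time}"    # ".log" is appended
#     concurrency:
#       llm_max_workers: 10
#       thread_max_workers: 20
#       process_max_workers: 10
#     tracing:
#       enabled: false
#       patch_threading: true
#       strict_match: true
#       path_format: "./traces/{basename}_{funcname}_{time}"    # ".pkl" is appended
# ----------------------------------------------------------------------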

@contextmanager
def init_within_thread(
    log_file_prefix: Optional[str] = None, gen_name_prefix: Optional[str] = None
) -> Any:
    """Initialize APPL to work with multi-threading, including logging and tracing.

    Args:
        log_file_prefix: The prefix for the log file. Defaults to a path derived from the main log file.
        gen_name_prefix: The prefix for the generation name. Defaults to the thread name.

    Examples:
        ```python
        def function_run_in_thread():
            with appl.init_within_thread():
                # do something within the thread
                ...
        ```
    """
    handler_id = None

    try:
        thread_name = threading.current_thread().name
        log_format = configs.getattrs("settings.logging.format")
        log_file = configs.getattrs("settings.logging.log_file")

        def filter_thread_record(record: Dict) -> bool:
            assert hasattr(record["thread"], "name")
            # Use prefix match to filter the log records in different threads
            name = record["thread"].name
            return name == thread_name or name.startswith(thread_name + "_")

        if log_file.get("enabled", False):
            if log_file_prefix is None:
                if "path" not in log_file:
                    raise ValueError(
                        "main log file is not set, did you forget to call appl.init()?"
                    )
                thread_log_path = os.path.join(
                    log_file.path[: -len(".log")] + "_logs", f"{thread_name}.log"
                )
            else:
                thread_log_path = f"{log_file_prefix}_{thread_name}.log"

            log_level = log_file.get("log_level", None)
            log_level = log_level or configs.getattrs("settings.logging.log_level")
            # The logger appends to the file by default rather than overwriting it.
            handler_id = logger.add(
                thread_log_path,
                level=log_level,
                format=log_format,
                filter=filter_thread_record,  # type: ignore
            )
        if gen_name_prefix:
            set_gen_name_prefix(gen_name_prefix)
            # ? shall we reset the prefix after exiting the context?
            logger.info(
                f"Thread {thread_name}, set generation name prefix as: {gen_name_prefix}"
            )

        if handler_id is None:
            logger.warning("logging is not enabled")
        yield thread_log_path
    except Exception as e:
        logger.error(f"Error in thread: {e}")
        raise e
    finally:
        if handler_id:
            logger.remove(handler_id)
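

# ----------------------------------------------------------------------
# A minimal usage sketch: how a separate caller script might combine
# appl.init() with appl.init_within_thread() when fanning work out to
# threads. The `worker` function, the "task" generation-name prefix, and
# the worker count are assumed names/values for illustration; the sketch
# also assumes file logging is enabled in the loaded config.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    from concurrent.futures import ThreadPoolExecutor

    import appl

    # Load `.env` and appl config files found near the caller, then set up
    # logging, tracing, and the shared thread/process executors.
    appl.init()

    def worker(task_id: int) -> str:
        # Per-thread setup: a dedicated log file for this thread and a
        # generation-name prefix derived from the task id.
        with appl.init_within_thread(gen_name_prefix=f"task{task_id}"):
            return f"task {task_id} done"

    with ThreadPoolExecutor(max_workers=4) as pool:
        for result in pool.map(worker, range(4)):
            print(result)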