Coverage for src/appl/ 73%

145 statements  

« prev     ^ index     » next v7.6.7, created at 2024-11-22 15:39 -0800

1"""appl - A Prompt Programming Language.""" 


3from __future__ import annotations 


5import datetime 

6import inspect 

7import os 

8import sys 

9import threading 

10from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor 

11from contextlib import contextmanager 


13import pendulum 

14import toml 

15import yaml 

16from dotenv import load_dotenv 

17from loguru import logger 


19logger.remove() # Remove default handler 

20logger.add(sys.stderr, level="INFO") # set to INFO 


22from typing import Any, Callable, Dict, Optional 


24from .compositor import ApplStr as Str 

25from .compositor import iter 

26from .core import ( 

27 BracketedDefinition, 

28 CallFuture, 

29 CompletionResponse, 

30 Definition, 

31 Generation, 

32 Image, 

33 Indexing, 

34 Promptable, 

35 PromptContext, 

36 PromptPrinter, 

37 PromptRecords, 

38 StringFuture, 

39 Tool, 


41from .core import appl_compile as compile 

42from .core import appl_execute as execute 

43from .core import appl_format as format 

44from .core import appl_with_ctx as with_ctx 

45from .core.config import Configs, configs, load_config 

46from .core.generation import get_gen_name_prefix, set_gen_name_prefix 

47from .core.globals import global_vars 

48from import dump_file, load_file 

49from .core.message import ( 

50 AIMessage, 

51 Conversation, 

52 SystemMessage, 

53 ToolMessage, 

54 UserMessage, 


56from .core.patch import patch_threading 

57from .core.promptable import define, define_bracketed, promptify 

58from .core.trace import traceable 

59from .core.utils import need_ctx, partial, wraps 

60from .func import ( 

61 as_func, 

62 as_tool, 

63 as_tool_choice, 

64 call, 

65 convo, 

66 empty_line, 

67 gen, 

68 grow, 

69 openai_tool_schema, 

70 ppl, 

71 records, 

72 reset_context, 

73 str_future, 


75from .role_changer import AIRole, SystemRole, ToolRole, UserRole 

76from .servers import server_manager 

77from .tracing import TraceEngine 

78from .utils import ( 

79 LoguruFormatter, 

80 find_dotenv, 

81 find_files, 

82 get_folder, 

83 get_meta_file, 

84 timeit, 


86from .version import __version__ 



def _get_loguru_format():
    """Build the loguru ``format`` argument from the current logging settings.

    Reads ``settings.logging.max_length`` and ``settings.logging.suffix_length``
    from the global ``configs`` at call time, so the result reflects whatever
    configuration is loaded when this function is invoked.

    Returns:
        The ``loguru_format`` attribute of a freshly created ``LoguruFormatter``.
    """
    max_length = configs.getattrs("settings.logging.max_length")
    suffix_length = configs.getattrs("settings.logging.suffix_length")
    formatter = LoguruFormatter(max_length=max_length, suffix_length=suffix_length)
    return formatter.loguru_format



# Replace every handler registered so far with a single stderr handler that
# uses the APPL message format; `init()` re-registers handlers later with the
# configured level.
logger.remove()
logger.add(sys.stderr, format=_get_loguru_format(), level="INFO")  # default

# `init()` checks this flag and sets it to True so that it only runs once.
global_vars.initialized = False



def init(
    resume_cache: Optional[str] = None,
    update_config_hook: Optional[Callable] = None,
) -> None:
    """Initialize APPL with dotenv and config files.

    Only the first call has an effect; later calls log a warning and return.
    The steps run in this order:

    1. Load the ``.env`` files and ``appl.yaml/.yml/.json/.toml`` config files
       that ``find_files`` returns for the caller's folder (later entries are
       applied first, so earlier entries override them).
    2. Call ``update_config_hook(configs)`` if a hook is given.
    3. Re-register the loguru handlers (stderr, plus a log file if enabled).
    4. Store a summary under ``configs["info"]``.
    5. Create the thread and process pool executors on ``global_vars``.
    6. Set up the trace engine (if tracing is enabled) and the resume cache.

    Args:
        resume_cache: Path to the trace file used as resume cache. When it is
            not given, the ``APPL_RESUME_TRACE`` environment variable is used.
            Defaults to None.
        update_config_hook: A hook to update the configs; it is called with
            the global ``configs`` object after the config files are loaded.
            Defaults to None.

    Examples:
        ```python
        import appl

        # Load environment variables from `.env` and configs from `appl.yaml`.
        # Initialize logging and tracing systems if enabled.
        appl.init()
        ```
    """
    with global_vars.lock:
        # only initialize once; the flag is checked and set while holding
        # ``global_vars.lock``
        if global_vars.initialized:
            logger.warning("APPL has already been initialized, ignore")
            return
        global_vars.initialized = True

    # NOTE(review): the argument of this call was missing in the reviewed
    # text; the current local time is assumed -- confirm.
    now = pendulum.instance(datetime.datetime.now())

    # Get the previous frame in the stack, i.e., the one calling this function
    frame = inspect.currentframe()
    caller = frame.f_back if frame else None
    # Drop the reference to our own frame: a frame stored in its own locals
    # forms a reference cycle.
    del frame

    # NOTE(review): the message-only log calls in this function are written
    # as ``logger.info``; the call heads were missing in the reviewed text --
    # confirm the intended level.
    if caller is not None:
        caller_path = caller.f_code.co_filename  # file path of the caller
        caller_funcname = caller.f_code.co_name  # function name of the caller
        caller_basename = os.path.basename(caller_path).split(".")[0]
        caller_folder = get_folder(os.path.dirname(caller_path))
        dotenvs = find_files(caller_folder, [".env"])
        appl_config_files = find_files(
            caller_folder, ["appl.yaml", "appl.yml", "appl.json", "appl.toml"]
        )
        # load dotenvs and appl configs from outer to inner with override
        for dotenv in dotenvs[::-1]:
            load_dotenv(dotenv, override=True)
            logger.info("Loaded dotenv from {}".format(dotenv))
        for config_file in appl_config_files[::-1]:
            override_configs = load_config(config_file)
            logger.info("Loaded configs from {}".format(config_file))
            configs.update(override_configs)
            if configs.getattrs("settings.logging.display.configs_update"):
                logger.info(
                    f"update configs:\n{yaml.dump(override_configs.to_dict())}"
                )
    else:
        # Fallback names used in the log/trace file path templates below.
        caller_basename, caller_funcname = "appl", "<module>"
        dotenvs, appl_config_files = [], []
        logger.error(
            "Cannot find the caller of appl.init(), fail to load .env and appl configs"
        )

    if update_config_hook:
        update_config_hook(configs)

    # ============================================================
    #  Logging
    # ============================================================
    log_format = configs.getattrs("settings.logging.format")
    log_level = configs.getattrs("settings.logging.log_level")
    log_file = configs.getattrs("settings.logging.log_file")
    # set logger level for loguru
    logger.remove()  # Remove the handlers registered at import time
    logger.add(sys.stderr, level=log_level, format=_get_loguru_format())
    if log_file.get("enabled", False):
        if (log_file_format := log_file.get("path_format", None)) is not None:
            log_file_path = (
                log_file_format.format(
                    basename=caller_basename, funcname=caller_funcname, time=now
                )
                + ".log"
            )
            log_file.path = log_file_path  # set the top level log file path
            file_log_level = log_file.get("log_level", None) or log_level
            logger.info(
                f"Logging to file: {log_file_path} with level {file_log_level}"
            )
            # no need to overwrite the default format when writing to file
            logger.add(log_file_path, level=file_log_level, format=log_format)
        else:
            # Without a path template there is no file to log to; report it
            # instead of touching an unbound ``log_file_path``.
            logger.warning("Logging to file is enabled but no path format is specified")

    configs["info"] = Configs(
        {
            "start_time": now.format("YYYY-MM-DD HH:mm:ss"),
            "dotenvs": dotenvs,
            "appl_configs": appl_config_files,
        }
    )
    if configs.getattrs("settings.logging.display.configs"):
        logger.info(f"Using configs:\n{yaml.dump(configs.to_dict())}")

    # ============================================================
    #  Concurrency
    # ============================================================
    concurrency = configs.getattrs("settings.concurrency")
    llm_max_workers = concurrency.get("llm_max_workers", 10)
    thread_max_workers = concurrency.get("thread_max_workers", 20)
    process_max_workers = concurrency.get("process_max_workers", 10)
    global_vars.llm_thread_executor = ThreadPoolExecutor(
        max_workers=llm_max_workers, thread_name_prefix="llm"
    )
    global_vars.thread_executor = ThreadPoolExecutor(
        max_workers=thread_max_workers, thread_name_prefix="general"
    )
    global_vars.process_executor = ProcessPoolExecutor(max_workers=process_max_workers)

    # ============================================================
    #  Tracing
    # ============================================================
    tracing = configs.getattrs("settings.tracing")
    # Read outside the ``enabled`` check because the resume cache below uses
    # it as well.
    strict_match = tracing.get("strict_match", True)
    if tracing.get("enabled", False):
        if tracing.get("patch_threading", True):
            patch_threading()
        if (trace_file_format := tracing.get("path_format", None)) is not None:
            prefix = trace_file_format.format(
                basename=caller_basename, funcname=caller_funcname, time=now
            )
            trace_file_path = f"{prefix}.pkl"
            meta_file = f"{prefix}_meta.json"
            tracing.trace_file = trace_file_path
            logger.info(f"Tracing file: {trace_file_path}")
            # Dump the effective configs next to the trace file.
            dump_file(configs.to_dict(), meta_file)
            global_vars.trace_engine = TraceEngine(
                trace_file_path, mode="write", strict=strict_match
            )
        else:
            logger.warning("Tracing is enabled but no trace file is specified")

    # The explicit argument wins over the environment variable.
    resume_cache = resume_cache or os.environ.get("APPL_RESUME_TRACE", None)
    if resume_cache:
        # NOTE(review): this path string is replaced by the TraceEngine a few
        # lines below; it looks like a redundant store -- confirm nothing
        # reads it in between before removing.
        global_vars.resume_cache = resume_cache
        logger.info(f"Using resume cache: {resume_cache}")
        global_vars.resume_cache = TraceEngine(
            resume_cache, mode="read", strict=strict_match
        )




@contextmanager
def init_within_thread(
    log_file_prefix: Optional[str] = None, gen_name_prefix: Optional[str] = None
) -> Any:
    """Initialize APPL to work with multi-threading, including logging and tracing.

    Use as a context manager inside the thread. When file logging is enabled
    in the configs, a log file handler restricted to this thread (and threads
    whose name starts with ``<thread name>_``) is registered for the duration
    of the context and removed again on exit.

    Args:
        log_file_prefix: The prefix for the log file. Defaults to use the path of the main log file.
        gen_name_prefix: The prefix for the generation name. Defaults to use the thread name.
            This function only sets the prefix when one is given.

    Yields:
        The path of the per-thread log file, or None when file logging is
        not enabled.

    Raises:
        ValueError: If file logging is enabled, ``log_file_prefix`` is None and
            the main log file path is not set in the configs.

    Examples:
        ```python
        def function_run_in_thread():
            with appl.init_within_thread():
                # do something within the thread
        ```
    """
    handler_id = None
    # Bound up front so the ``yield`` below is well-defined even when file
    # logging is disabled.
    thread_log_path: Optional[str] = None

    try:
        thread_name = threading.current_thread().name
        log_format = configs.getattrs("settings.logging.format")
        log_file = configs.getattrs("settings.logging.log_file")

        def filter_thread_record(record: Dict) -> bool:
            assert hasattr(record["thread"], "name")
            # Use prefix match to filter the log records in different threads
            name = record["thread"].name
            return name == thread_name or name.startswith(thread_name + "_")

        if log_file.get("enabled", False):
            if log_file_prefix is None:
                if "path" not in log_file:
                    raise ValueError(
                        "main log file is not set, did you forget to call appl.init()?"
                    )
                # "<main log path without .log>_logs/<thread name>.log"
                thread_log_path = os.path.join(
                    log_file.path[: -len(".log")] + "_logs", f"{thread_name}.log"
                )
            else:
                thread_log_path = f"{log_file_prefix}_{thread_name}.log"

            log_level = log_file.get("log_level", None)
            log_level = log_level or configs.getattrs("settings.logging.log_level")
            # The logger append to the file by default, not overwrite.
            handler_id = logger.add(
                thread_log_path,
                level=log_level,
                format=log_format,
                filter=filter_thread_record,  # type: ignore
            )
        if gen_name_prefix:
            set_gen_name_prefix(gen_name_prefix)
            # ? shall we reset the prefix after exiting the context?
            # NOTE(review): log level assumed to be info -- confirm.
            logger.info(
                f"Thread {thread_name}, set generation name prefix as: {gen_name_prefix}"
            )

        if handler_id is None:
            logger.warning("logging is not enabled")
        yield thread_log_path
    except Exception as e:
        logger.error(f"Error in thread: {e}")
        raise
    finally:
        # Compare against None explicitly: a handler id of 0 is falsy and would
        # otherwise never be removed.
        if handler_id is not None:
            logger.remove(handler_id)