Coverage for src/appl/__init__.py: 73%
145 statements
coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1"""appl - A Prompt Programming Language."""
3from __future__ import annotations
5import datetime
6import inspect
7import os
8import sys
9import threading
10from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
11from contextlib import contextmanager
13import pendulum
14import toml
15import yaml
16from dotenv import load_dotenv
17from loguru import logger
19logger.remove() # Remove default handler
20logger.add(sys.stderr, level="INFO") # set to INFO
22from typing import Any, Callable, Dict, Optional

from .compositor import ApplStr as Str
from .compositor import iter
from .core import (
    BracketedDefinition,
    CallFuture,
    CompletionResponse,
    Definition,
    Generation,
    Image,
    Indexing,
    Promptable,
    PromptContext,
    PromptPrinter,
    PromptRecords,
    StringFuture,
    Tool,
)
from .core import appl_compile as compile
from .core import appl_execute as execute
from .core import appl_format as format
from .core import appl_with_ctx as with_ctx
from .core.config import Configs, configs, load_config
from .core.generation import get_gen_name_prefix, set_gen_name_prefix
from .core.globals import global_vars
from .core.io import dump_file, load_file
from .core.message import (
    AIMessage,
    Conversation,
    SystemMessage,
    ToolMessage,
    UserMessage,
)
from .core.patch import patch_threading
from .core.promptable import define, define_bracketed, promptify
from .core.trace import traceable
from .core.utils import need_ctx, partial, wraps
from .func import (
    as_func,
    as_tool,
    as_tool_choice,
    call,
    convo,
    empty_line,
    gen,
    grow,
    openai_tool_schema,
    ppl,
    records,
    reset_context,
    str_future,
)
from .role_changer import AIRole, SystemRole, ToolRole, UserRole
from .servers import server_manager
from .tracing import TraceEngine
from .utils import (
    LoguruFormatter,
    find_dotenv,
    find_files,
    get_folder,
    get_meta_file,
    timeit,
)
from .version import __version__


def _get_loguru_format():
    return LoguruFormatter(
        max_length=configs.getattrs("settings.logging.max_length"),
        suffix_length=configs.getattrs("settings.logging.suffix_length"),
    ).loguru_format


logger.remove()  # Remove default handler
# update default handler for the loguru logger
logger.add(sys.stderr, level="INFO", format=_get_loguru_format())  # default
global_vars.initialized = False


def init(
    resume_cache: Optional[str] = None,
    update_config_hook: Optional[Callable] = None,
) -> None:
    """Initialize APPL with dotenv and config files.

    Args:
        resume_cache: Path to the trace file used as resume cache. Defaults to None.
        update_config_hook: A hook to update the configs. Defaults to None.

    Examples:
        ```python
        import appl

        # Load environment variables from `.env` and configs from `appl.yaml`.
        # Initialize logging and tracing systems if enabled.
        appl.init()
        ```
    """
    with global_vars.lock:
        # only initialize once
        if global_vars.initialized:
            logger.warning("APPL has already been initialized, ignore")
            return
        global_vars.initialized = True

    now = pendulum.instance(datetime.datetime.now())
    # Get the previous frame in the stack, i.e., the one calling this function
    frame = inspect.currentframe()
    if frame and frame.f_back:
        caller_path = frame.f_back.f_code.co_filename  # Get file_path of the caller
        caller_funcname = frame.f_back.f_code.co_name  # Get function name of the caller
        caller_basename = os.path.basename(caller_path).split(".")[0]
        caller_folder = os.path.dirname(caller_path)  # Get folder of the caller
        caller_folder = get_folder(caller_folder)
        dotenvs = find_files(caller_folder, [".env"])
        appl_config_files = find_files(
            caller_folder, ["appl.yaml", "appl.yml", "appl.json", "appl.toml"]
        )
        # load dotenvs and appl configs from outer to inner with override
        for dotenv in dotenvs[::-1]:
            load_dotenv(dotenv, override=True)
            logger.info("Loaded dotenv from {}".format(dotenv))
        for config_file in appl_config_files[::-1]:
            override_configs = load_config(config_file)
            logger.info("Loaded configs from {}".format(config_file))
            configs.update(override_configs)
            if configs.getattrs("settings.logging.display.configs_update"):
                logger.info(f"update configs:\n{yaml.dump(override_configs.to_dict())}")
    else:
        caller_basename, caller_funcname = "appl", "<module>"
        dotenvs, appl_config_files = [], []
        logger.error(
            "Cannot find the caller of appl.init(), fail to load .env and appl configs"
        )

    if update_config_hook:
        update_config_hook(configs)

    # ============================================================
    # Logging
    # ============================================================
    log_format = configs.getattrs("settings.logging.format")
    log_level = configs.getattrs("settings.logging.log_level")
    log_file = configs.getattrs("settings.logging.log_file")
    # set logger level for loguru
    logger.remove()  # Remove default handler
    logger.add(sys.stderr, level=log_level, format=_get_loguru_format())
    if log_file.get("enabled", False):
        if (log_file_format := log_file.get("path_format", None)) is not None:
            log_file_path = (
                log_file_format.format(
                    basename=caller_basename, funcname=caller_funcname, time=now
                )
                + ".log"
            )
            log_file.path = log_file_path  # set the top level log file path
            file_log_level = log_file.get("log_level", None) or log_level
            logger.info(f"Logging to file: {log_file_path} with level {file_log_level}")
            # no need to overwrite the default format when writing to file
            logger.add(log_file_path, level=file_log_level, format=log_format)
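
    # Illustrative `appl.yaml` logging snippet (assumed schema, inferred from the
    # config keys read above; the values shown are hypothetical, not project defaults):
    #
    #   settings:
    #     logging:
    #       log_level: INFO
    #       format: "<green>{time}</green> | {level} | {message}"  # loguru-style format
    #       log_file:
    #         enabled: true
    #         path_format: "./logs/{basename}_{funcname}_{time}"  # ".log" is appended
    #         log_level: DEBUG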
184 configs["info"] = Configs(
185 {
186 "start_time": now.format("YYYY-MM-DD HH:mm:ss"),
187 "dotenvs": dotenvs,
188 "appl_configs": appl_config_files,
189 }
190 )
191 if configs.getattrs("settings.logging.display.configs"):
192 logger.info(f"Using configs:\n{yaml.dump(configs.to_dict())}")

    # ============================================================
    # Concurrency
    # ============================================================
    concurrency = configs.getattrs("settings.concurrency")
    llm_max_workers = concurrency.get("llm_max_workers", 10)
    thread_max_workers = concurrency.get("thread_max_workers", 20)
    process_max_workers = concurrency.get("process_max_workers", 10)
    global_vars.llm_thread_executor = ThreadPoolExecutor(
        max_workers=llm_max_workers, thread_name_prefix="llm"
    )
    global_vars.thread_executor = ThreadPoolExecutor(
        max_workers=thread_max_workers, thread_name_prefix="general"
    )
    global_vars.process_executor = ProcessPoolExecutor(max_workers=process_max_workers)
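
    # Illustrative `appl.yaml` concurrency snippet (assumed schema, inferred from the
    # keys read above; the values shown are the fallback defaults used in this function):
    #
    #   settings:
    #     concurrency:
    #       llm_max_workers: 10
    #       thread_max_workers: 20
    #       process_max_workers: 10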

    # ============================================================
    # Tracing
    # ============================================================
    tracing = configs.getattrs("settings.tracing")
    strict_match = tracing.get("strict_match", True)
    if tracing.get("enabled", False):
        if tracing.get("patch_threading", True):
            patch_threading()
        if (trace_file_format := tracing.get("path_format", None)) is not None:
            prefix = trace_file_format.format(
                basename=caller_basename, funcname=caller_funcname, time=now
            )
            trace_file_path = f"{prefix}.pkl"
            meta_file = f"{prefix}_meta.json"
            tracing.trace_file = trace_file_path
            logger.info(f"Tracing file: {trace_file_path}")
            dump_file(configs.to_dict(), meta_file)
            global_vars.trace_engine = TraceEngine(
                trace_file_path, mode="write", strict=strict_match
            )
        else:
            logger.warning("Tracing is enabled but no trace file is specified")
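
    # Illustrative `appl.yaml` tracing snippet (assumed schema, inferred from the keys
    # read above; the path format is expanded like the logging one, with ".pkl" appended):
    #
    #   settings:
    #     tracing:
    #       enabled: true
    #       patch_threading: true
    #       strict_match: true
    #       path_format: "./traces/{basename}_{funcname}_{time}"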

    resume_cache = resume_cache or os.environ.get("APPL_RESUME_TRACE", None)
    if resume_cache:
        global_vars.resume_cache = resume_cache
        logger.info(f"Using resume cache: {resume_cache}")
        global_vars.resume_cache = TraceEngine(
            resume_cache, mode="read", strict=strict_match
        )
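
# Usage sketch (illustrative comment, not part of this module): `update_config_hook`
# receives the global `configs` object before logging and tracing are set up, and the
# resume cache may alternatively be supplied via the APPL_RESUME_TRACE environment
# variable. Hypothetical caller code:
#
#     import appl
#
#     def adjust(configs):  # hypothetical hook
#         configs["experiment"] = appl.Configs({"seed": 0})  # item assignment, as used above
#
#     appl.init(resume_cache="./traces/previous_run.pkl", update_config_hook=adjust)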


@contextmanager
def init_within_thread(
    log_file_prefix: Optional[str] = None, gen_name_prefix: Optional[str] = None
) -> Any:
    """Initialize APPL to work with multi-threading, including logging and tracing.

    Args:
        log_file_prefix: The prefix for the log file. Defaults to the path of the main log file.
        gen_name_prefix: The prefix for the generation name. Defaults to the thread name.

    Examples:
        ```python
        def function_run_in_thread():
            with appl.init_within_thread():
                pass  # do something within the thread
        ```
    """
    handler_id = None
    thread_log_path = None  # default when file logging is disabled

    try:
        thread_name = threading.current_thread().name
        log_format = configs.getattrs("settings.logging.format")
        log_file = configs.getattrs("settings.logging.log_file")

        def filter_thread_record(record: Dict) -> bool:
            assert hasattr(record["thread"], "name")
            # Use prefix match to filter the log records in different threads
            name = record["thread"].name
            return name == thread_name or name.startswith(thread_name + "_")

        if log_file.get("enabled", False):
            if log_file_prefix is None:
                if "path" not in log_file:
                    raise ValueError(
                        "main log file is not set, did you forget to call appl.init()?"
                    )
                thread_log_path = os.path.join(
                    log_file.path[: -len(".log")] + "_logs", f"{thread_name}.log"
                )
            else:
                thread_log_path = f"{log_file_prefix}_{thread_name}.log"

            log_level = log_file.get("log_level", None)
            log_level = log_level or configs.getattrs("settings.logging.log_level")
            # The logger appends to the file by default, not overwrite.
            handler_id = logger.add(
                thread_log_path,
                level=log_level,
                format=log_format,
                filter=filter_thread_record,  # type: ignore
            )
        if gen_name_prefix:
            set_gen_name_prefix(gen_name_prefix)
            # ? shall we reset the prefix after exiting the context?
            logger.info(
                f"Thread {thread_name}, set generation name prefix as: {gen_name_prefix}"
            )

        if handler_id is None:
            logger.warning("logging is not enabled")
        yield thread_log_path
    except Exception as e:
        logger.error(f"Error in thread: {e}")
        raise e
    finally:
        if handler_id:
            logger.remove(handler_id)
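
# Usage sketch (illustrative comment, not part of this module): a worker submitted to
# the thread pool created by `init()` can enter `init_within_thread()` so its loguru
# records are routed to a per-thread log file. Hypothetical caller code:
#
#     import appl
#     from appl import global_vars
#
#     def worker(i):
#         with appl.init_within_thread(gen_name_prefix=f"job{i}"):
#             return i * i  # thread-local work goes here
#
#     appl.init()
#     futures = [global_vars.thread_executor.submit(worker, i) for i in range(3)]
#     results = [f.result() for f in futures]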