Coverage for src/appl/core/generation.py: 84% (237 statements)
from __future__ import annotations

import json
import threading
from typing import (
    Any,
    Callable,
    Generic,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

from loguru import logger
from pydantic import BaseModel
from rich.live import Live

from .config import configs
from .context import PromptContext
from .globals import (
    get_thread_local,
    inc_global_var,
    inc_thread_local,
    set_thread_local,
)
from .message import AIMessage, BaseMessage, ToolMessage, UserMessage
from .promptable import Promptable
from .response import CompletionResponse
from .server import BaseServer, GenArgs
from .tool import BaseTool, ToolCall
from .trace import GenerationInitEvent, GenerationResponseEvent, add_to_trace
from .types import (
    CallFuture,
    ExecutorType,
    MessageRole,
    MessageRoleType,
    ResponseType,
    String,
    StringFuture,
)
from .utils import get_live, split_last, stop_live, strip_for_continue

M = TypeVar("M")
APPL_GEN_NAME_PREFIX_KEY = "_appl_gen_name_prefix"
LAST_LINE_MARKER = "<last_line>"
LAST_PART_MARKER = "<last_part>"

def set_gen_name_prefix(prefix: str) -> None:
    """Set the prefix for generation names in the current thread."""
    set_thread_local(APPL_GEN_NAME_PREFIX_KEY, prefix)


def get_gen_name_prefix() -> Optional[str]:
    """Get the prefix for generation names in the current thread."""
    gen_name_prefix = get_thread_local(APPL_GEN_NAME_PREFIX_KEY, None)
    if gen_name_prefix is None:
        thread_name = threading.current_thread().name
        if thread_name != "MainThread":
            gen_name_prefix = thread_name
    return gen_name_prefix
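
# Usage sketch (illustrative, not part of the module): a worker thread can label
# its generations before creating them; when no prefix is set, non-main threads
# fall back to the thread name.
#
#     set_gen_name_prefix("indexer")
#     assert get_gen_name_prefix() == "indexer"
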

class Generation(Generic[M]):
    """Represents a generation call to the model."""

    def __init__(
        self,
        server: BaseServer,
        args: GenArgs,
        *,
        max_relay_rounds: int = 0,
        mock_response: Optional[Union[CompletionResponse, str]] = None,
        llm_executor_type: ExecutorType = ExecutorType.LLM_THREAD_POOL,
        lazy_eval: bool = False,
        _ctx: Optional[PromptContext] = None,
        **kwargs: Any,
        # kwargs used for extra args for the create method
    ) -> None:
        """Initialize the Generation object.

        Args:
            server: An LLM server where the generation request will be sent.
            args: The arguments of the generation call.
            max_relay_rounds: The maximum number of relay rounds used to continue
                unfinished text generation.
            mock_response: A mock response for the generation call.
            llm_executor_type: The type of executor used to run the LLM call.
            lazy_eval: If True, the generation call will be evaluated lazily.
            _ctx: The prompt context, filled automatically by the APPL function.
            **kwargs: Extra arguments for the generation call.
        """
        # name needs to be unique and ordered, so it has to be generated in the main thread
        gen_name_prefix = get_gen_name_prefix()
        # take the value before increment
        self._cnt = inc_thread_local(f"{gen_name_prefix}_gen_cnt") - 1
        if gen_name_prefix is None:
            self._id = f"@gen_{self._cnt}"
        else:
            self._id = f"@{gen_name_prefix}_gen_{self._cnt}"

        self._server = server
        self._model_name = server.model_name
        self._args = args
        self._max_relay_rounds = max_relay_rounds
        self._mock_response = mock_response
        self._llm_executor_type = llm_executor_type
        self._lazy_eval = lazy_eval
        self._ctx = _ctx
        self._extra_args = kwargs
        self._num_raw_completions = 0
        self._cached_response: Optional[CompletionResponse] = None

        add_to_trace(GenerationInitEvent(name=self.id))
        log_llm_call_args = configs.getattrs("settings.logging.display.llm_call_args")
        if log_llm_call_args:
            logger.info(
                f"Call generation [{self.id}] with args: {args} and kwargs: {kwargs}"
            )

        if isinstance(mock_response, CompletionResponse):

            def get_response() -> CompletionResponse:
                return mock_response

            self._call = self._wrap_response(get_response)
        else:
            if mock_response:
                # use litellm's mock response
                kwargs.update({"mock_response": mock_response})
            self._call = self._wrap_response(self._call_llm())

        # tools
        self._tools: Sequence[BaseTool] = args.tools
        self._name2tools = {tool.name: tool for tool in self._tools}
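
    # Naming note (added for clarity): with no prefix set in the main thread, the
    # generated ids look like "@gen_0", "@gen_1", ...; after set_gen_name_prefix("worker")
    # in a thread, they look like "@worker_gen_0", "@worker_gen_1", ...
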

    def _call_llm(self) -> CallFuture[CompletionResponse]:
        """Call the LLM server asynchronously to get the completion response."""
        self._num_raw_completions += 1
        return CallFuture(
            self._server.create,
            executor_type=self._llm_executor_type,
            lazy_eval=self._lazy_eval,
            args=self._args,
            gen_id=f"{self.id}_raw_{self._num_raw_completions - 1}",
            **self._extra_args,
        )
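
    # Note (added for clarity): each relay round issues a fresh raw completion, so
    # the underlying requests are identified as "@gen_0_raw_0", "@gen_0_raw_1", ...
    # while the Generation itself keeps the single id "@gen_0".
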

    def _continue_llm(
        self, results: CompletionResponse, live: Optional[Live] = None
    ) -> CompletionResponse:
        assert results.message is not None, "Cannot continue from an empty message"

        cutoff_content = strip_for_continue(results.message)
        continue_prompt = configs.getattrs("prompts.continue_generation")
        continue_prompt_alt = configs.getattrs("prompts.continue_generation_alt")

        # Choose a proper split marker for the continuation
        for split_marker in ["\n", " ", ","]:
            content, last_part = split_last(cutoff_content, split_marker)
            if content is not None:  # found split_marker in the content
                prompt = (
                    continue_prompt if split_marker == "\n" else continue_prompt_alt
                )
                marker = LAST_LINE_MARKER if split_marker == "\n" else LAST_PART_MARKER
                break
        marked_cutoff_content = f"{content}{split_marker}{marker}{last_part}{marker}"
        prompt = prompt.format(last_marker=marker)

        messages = self._args.messages
        messages.append(AIMessage(content=marked_cutoff_content))
        messages.append(UserMessage(content=prompt))
        # print(messages, "\n")  # DEBUG

        # call the LLM again and wait for the result
        response = self._call_llm().result()
        if response.type == ResponseType.UNFINISHED:
            response.streaming(
                title=f"Continue generation [{self.id}]",
                display_prefix_content=marked_cutoff_content + "\n",
                live=live,
            )

        # pop the last two messages
        for _ in range(2):
            messages.pop()

        results.update(response, split_marker)
        return response
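
    # Continuation sketch (illustrative; assumes split_last splits around the last
    # occurrence of the marker): for a cutoff message ending with
    # "...step 2.\nstep 3 is to", splitting on "\n" gives
    #     content="...step 2.", last_part="step 3 is to"
    # so the model sees "...step 2.\n<last_line>step 3 is to<last_line>" together
    # with a continue prompt that references the <last_line> marker.
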

    def _wrap_response(
        self, get_response: Callable[[], CompletionResponse]
    ) -> Callable[[], CompletionResponse]:
        """Wrap the LLM call to handle incomplete completions."""

        def inner() -> CompletionResponse:
            log_llm_usage = configs.getattrs("settings.logging.display.llm_usage")
            log_llm_response = configs.getattrs("settings.logging.display.llm_response")
            log_llm_cost = configs.getattrs("settings.logging.display.llm_cost")

            results = response = get_response()

            if self._max_relay_rounds > 0:
                live = None
                display_mode = configs.getattrs(
                    "settings.logging.display.display_mode", "live"
                )
                need_live = self._args.stream and display_mode == "live"
                if response.type == ResponseType.UNFINISHED:
                    if need_live:
                        live = get_live()
                    response.streaming(title=f"Generation [{self.id}]", live=live)

                for i in range(self._max_relay_rounds):
                    if response.finish_reason in ["length"]:
                        generated_chars = len(results.message or "")
                        logger.info(
                            f"[Round {i + 1}/{self._max_relay_rounds}, "
                            f"generated {generated_chars} chars] "
                            f"Generation [{self.id}] was cut off due to max_tokens, "
                            "automatically continue the generation."
                        )
                        if need_live and live is None:
                            live = get_live()
                        response = self._continue_llm(results, live=live)
                    else:
                        break

                if live is not None:
                    stop_live()

            def handle_results(results: CompletionResponse) -> None:
                if log_llm_response:
                    logger.info(f"Generation [{self.id}] results: {results}")
                if results.usage and log_llm_usage:
                    logger.info(f"Generation [{self.id}] token usage: {results.usage}")

                num_requests = inc_global_var(f"{self._model_name}_num_requests")
                if log_llm_cost:
                    currency = getattr(self._server, "_cost_currency", "USD")
                    if self._mock_response is not None:
                        logger.info(
                            "Mock response, estimated cost for real request: "
                            f"{results.cost:.4f} {currency}"
                        )
                    elif results.cost is None:
                        logger.warning(
                            f"No cost information for generation [{self.id}]"
                        )
                    else:
                        total_cost = inc_global_var(
                            f"{self._model_name}_api_cost", results.cost
                        )
                        logger.info(
                            f"API cost for this request: {results.cost:.4f}, "
                            f"in total: {total_cost:.4f} {currency}. "
                            f"Total number of requests: {num_requests}."
                        )
                create_args = self._server._get_create_args(
                    self._args, **self._extra_args
                )
                dump_args = create_args.copy()
                for k, v in dump_args.items():
                    if k in ["response_format", "response_model"]:
                        if isinstance(v, type) and issubclass(v, BaseModel):
                            dump_args[k] = json.dumps(v.model_json_schema(), indent=4)

                add_to_trace(
                    GenerationResponseEvent(
                        name=self.id, args=dump_args, ret=str(results)
                    )
                )

            results.register_post_finish_callback(handle_results)

            return results

        return inner

    @property
    def id(self) -> str:
        """The unique ID of the generation."""
        return self._id

    @property
    def response(self) -> CompletionResponse:
        """The response of the generation call."""
        # NOTE: the result of the call should be cached
        if self._cached_response is None:
            self._cached_response = self._call()
        return self._cached_response

    @property
    def response_type(self) -> ResponseType:
        """The type of the response."""
        return self.response.type

    @property
    def is_message(self) -> bool:
        """Whether the response is a text message."""
        return self.response_type == ResponseType.TEXT

    @property
    def is_tool_call(self) -> bool:
        """Whether the response is a tool call."""
        return self.response_type == ResponseType.TOOL_CALL

    @property
    def is_obj(self) -> bool:
        """Whether the response is an object."""
        return self.response_type == ResponseType.OBJECT

    @property
    def message(self) -> Optional[str]:
        """The message of the response."""
        return self.response.message

    @property
    def tool_calls(self) -> List[ToolCall]:
        """The tool calls of the response."""
        return self.response.tool_calls

    @property
    def response_obj(self) -> M:
        """The object of the response."""
        return self.response.response_obj

    @property
    def results(self) -> Union[M, str, List[ToolCall]]:
        """The results of the response."""
        return self.response.results

    @property
    def str_future(self) -> StringFuture:
        """The StringFuture representation of the response."""
        return StringFuture(self)

    @property
    def text_stream(self):
        """Get the response of the generation as a text stream."""
        return self.response.format_stream()
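
    # Streaming sketch (illustrative; assumes format_stream yields text chunks):
    #
    #     for chunk in generation.text_stream:
    #         print(chunk, end="", flush=True)
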

    def _call_tool(
        self,
        name: str,
        args: str,
        parallel: bool = False,
        executor_type: ExecutorType = ExecutorType.GENERAL_THREAD_POOL,
    ) -> Any:
        try:
            kwargs = json.loads(args)
        except json.JSONDecodeError as e:
            raise ValueError(f"Error parsing args: {args}") from e
        args_str = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
        if configs.getattrs("settings.logging.display.tool_calls"):
            logger.info(f"Running tool call: {name}({args_str})")

        if name not in self._name2tools:
            raise ValueError(f"Error: Tool {name} not found")
        tool = self._name2tools[name]
        try:
            if parallel:
                res = CallFuture(tool, executor_type=executor_type, **kwargs)
            else:
                res = tool(**kwargs)
        except Exception as e:
            raise RuntimeError(f"Error running tool call: {name}({args_str})") from e

        return res

    def run_tool_calls(
        self,
        filter_fn: Optional[Callable[[List[ToolCall]], List[ToolCall]]] = None,
        parallel: bool = False,
        executor_type: ExecutorType = ExecutorType.GENERAL_THREAD_POOL,
        log_results: Optional[bool] = None,
    ) -> List[ToolMessage]:
        """Run all tool calls in the generation and return the results.

        Args:
            filter_fn:
                A function that takes a list of ToolCall objects and returns
                a filtered list of ToolCall objects. This function can be
                used to filter the tool calls that will be run.
            parallel: If True, run the tool calls in parallel. Defaults to False.
            executor_type:
                The type of executor used to run the tool calls; can be
                "general_thread_pool", "general_process_pool", "new_thread" or
                "new_process".
            log_results:
                If True, log the results of the tool calls. Note that this will
                wait for the results to be ready. Defaults to the setting in configs.

        Returns:
            A list of ToolMessage objects.
        """
        if not self.is_tool_call:
            raise ValueError("Error: The Generation is not a tool call")
        if log_results is None:
            log_results = configs.getattrs("settings.logging.display.tool_results")
        tool_calls = self.tool_calls
        if filter_fn:
            tool_calls = filter_fn(tool_calls)
        messages = []
        for tc in tool_calls:
            role = MessageRole(MessageRoleType.TOOL, tc.name)
            try:
                tool_result = self._call_tool(
                    tc.name, tc.args, parallel=parallel, executor_type=executor_type
                )
                msg = ToolMessage(
                    tool_result, role=role, tool_call_id=tc.id, has_error=False
                )
            except Exception as e:
                logger.error(f"Error running tool call: {tc.name}({tc.args})")
                logger.error(e)
                msg = ToolMessage(str(e), role=role, tool_call_id=tc.id, has_error=True)
            messages.append(msg)
        if log_results:  # this will wait for the results to be ready
            for msg in messages:
                logger.info(f"Tool call result: {msg}")
        return messages
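
    # Tool-loop sketch (illustrative): feed tool results back into the conversation,
    # e.g. in calling code one might write
    #
    #     if generation.is_tool_call:
    #         tool_messages = generation.run_tool_calls(parallel=True)
    #         # append tool_messages to the prompt and issue a follow-up generation
    #
    # where the follow-up step depends on the surrounding framework and is only a
    # placeholder here.
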

    def as_prompt(self) -> Union[AIMessage, StringFuture]:
        """Get the response of the generation as a promptable object."""
        if self._args.tools:
            if self.is_tool_call:
                return AIMessage(tool_calls=self.tool_calls)
        # return a future object of value: str(self._call()), without blocking
        return StringFuture(CallFuture(self._call))

    def __add__(self, other: Union[String, "Generation"]) -> StringFuture:
        # Assume the generation result is a string
        if isinstance(other, Generation):
            return self.str_future + other.str_future
        elif isinstance(other, (str, StringFuture)):
            return self.str_future + other
        raise TypeError(
            f"unsupported operand type(s) for +: 'Generation' and '{type(other)}'"
        )

    def __radd__(self, other: String) -> StringFuture:
        # Assume the generation result is a string
        if isinstance(other, (str, StringFuture)):
            return other + self.str_future
        raise TypeError(
            f"unsupported operand type(s) for +: '{type(other)}' and 'Generation'"
        )
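
    # Concatenation sketch (illustrative): since + delegates to StringFuture,
    # composing prompts does not block on the LLM result:
    #
    #     answer = "Final answer: " + generation  # a StringFuture, not yet resolved
    #     print(str(answer))                      # resolving it waits for the response
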

    def __getattr__(self, name: str) -> Any:
        assert name != "response", "Internal Error within self.response"
        return getattr(self.response, name)

    def __str__(self) -> str:
        return str(self.response.results)

    def __repr__(self) -> str:
        return f"Generation(id={self.id})"

    def __call__(self):
        """Get the response of the generation call."""
        return self.response
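
# Overall usage sketch (illustrative, not part of the module): Generation objects
# are normally produced by APPL's gen() helper inside a @ppl function rather than
# constructed directly; a string mock_response short-circuits the real server call:
#
#     g = gen(mock_response="Hello!")  # assumes the package-level gen() helper
#     print(g.message)                 # -> "Hello!" once the (mocked) call resolves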