Coverage for src/appl/core/response.py: 71%
266 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1import json
2import os
3import shutil
4import sys
5import time
6from typing import Any, Callable, List, Literal, Optional, Union
8from litellm import CustomStreamWrapper, completion_cost, stream_chunk_builder
9from litellm.exceptions import NotFoundError
10from litellm.types.utils import Delta, Function, ModelResponse
11from loguru import logger
12from openai import Stream
13from openai.types import CompletionUsage
14from openai.types.chat import (
15 ChatCompletion,
16 ChatCompletionChunk,
17 ChatCompletionMessageToolCall,
18)
19from openai.types.chat.chat_completion import Choice
20from openai.types.chat.chat_completion_chunk import (
21 ChoiceDelta,
22 ChoiceDeltaToolCallFunction,
23)
24from pydantic import BaseModel, Field, model_validator
25from rich.live import Live
26from rich.panel import Panel
27from rich.syntax import Syntax
28from termcolor import colored
29from termcolor._types import Color
31from .config import configs
32from .tool import ToolCall
33from .types import ResponseType
34from .utils import get_live, make_panel, split_last, stop_live, strip_for_continue
class CompletionResponse(BaseModel):
    """A class wrapping the response from the LLM model.

    For a streaming response, it tracks the chunks of the response and
    builds the complete response when the streaming is finished.
    """

    raw_response: Any = Field(None, description="The raw response from the model")
    """The raw response from the model."""
    cost: Optional[float] = Field(None, description="The cost of the completion")
    """The cost of the completion."""
    usage: Optional[CompletionUsage] = Field(
        None, description="The usage of the completion"
    )
    """The usage of the completion."""
    finish_reason: Optional[str] = Field(
        None, description="The reason why the completion is finished for the top-choice"
    )
    """The reason why the completion is finished for the top-choice."""
    num_raw_completions: int = Field(1, description="The number of raw completions")
    """The number of raw completions."""
    chunks: List[Union[ModelResponse, ChatCompletionChunk]] = Field(
        default_factory=list, description="The chunks of the response when streaming"
    )
    """The chunks of the response when streaming."""
    is_stream: bool = Field(False, description="Whether the response is a stream")
    """Whether the response is a stream."""
    is_finished: bool = Field(
        False, description="Whether the response stream is finished"
    )
    """Whether the response stream is finished."""
    post_finish_callbacks: List[Callable] = Field(
        default_factory=list, description="The post finish callbacks"
    )
    """The post finish callbacks."""
    response_model: Any = Field(
        None, description="The BaseModel's subclass specifying the response format."
    )
    """The BaseModel's subclass specifying the response format."""
    response_obj: Any = Field(
        None, description="The response object of response model, could be a stream"
    )
    """The response object of response model, could be a stream."""
    message: Optional[str] = Field(
        None, description="The top-choice message from the completion"
    )
    """The top-choice message from the completion."""
    tool_calls: List[ToolCall] = Field(
        default_factory=list, description="The tool calls"
    )
    """The tool calls."""

    @model_validator(mode="after")
    def _post_init(self) -> "CompletionResponse":
        """Mark streaming responses for lazy completion; finish others now."""
        self._complete_response = None
        if isinstance(self.raw_response, (CustomStreamWrapper, Stream)):
            # ? supports for Async Steam?
            self.is_stream = True
        else:
            # Non-streaming responses are parsed immediately.
            self._finish(self.raw_response)  # type: ignore
        return self

    def set_response_obj(self, response_obj: Any) -> None:
        """Set the response object."""
        self.response_obj = response_obj

    @property
    def complete_response(self) -> Union[ModelResponse, ChatCompletion]:
        """The complete response from the model. This will block until the response is finished."""
        if self.is_finished:
            return self._complete_response  # type: ignore
        # Consuming the stream finishes the response as a side effect.
        self.streaming()  # ? when we should set display to False?
        assert self.is_finished, "Response should be finished after streaming"
        return self._complete_response  # type: ignore

    @property
    def results(self) -> Any:
        """The results of the response.

        Returns:
            message (str):
                The message if the response is a text completion.
            tool_calls (List[ToolCall]):
                The tool calls if the response is a list of tool calls.
            response_obj (Any):
                The object if the response is a response object.
        """
        if self.is_stream and not self.is_finished:
            self.streaming()  # display the stream and finish the response
        # Priority: response object > tool calls > plain text message.
        results: Any = self.message
        if self.response_obj is not None:
            results = self.response_obj
        elif len(self.tool_calls):
            results = self.tool_calls
        return results

    @property
    def type(self) -> ResponseType:
        """The type of the response."""
        if not self.is_finished:
            return ResponseType.UNFINISHED
        if self.response_model is not None and self.response_obj is not None:
            return ResponseType.OBJECT
        if len(self.tool_calls):
            return ResponseType.TOOL_CALL
        return ResponseType.TEXT

    def update(
        self, other: "CompletionResponse", split_marker: str = "\n"
    ) -> "CompletionResponse":
        """Update the response with the information contained in the other response.

        Used to stitch a continuation onto this (finished, text-type) response:
        the overlapping tail of this message is matched against `other.message`
        and the two are merged, then cost/usage/chunk bookkeeping is combined.

        Raises:
            ValueError: If this response is unfinished, the types differ, or
                either message is None.
            NotImplementedError: If the response type is not TEXT.
        """
        if not self.is_finished:
            raise ValueError("Cannot update unfinished response")
        if self.type != other.type:
            raise ValueError(
                f"Cannot update response with type {self.type} "
                f"with another response of type {other.type}"
            )
        if self.type != ResponseType.TEXT:
            raise NotImplementedError("Not supported for non-text response")
        if self.message is None or other.message is None:
            raise ValueError("Not supported for empty message when updating")

        # Find the last segment of the (stripped) current message and use it
        # as an anchor to locate where the continuation overlaps.
        stripped_message = strip_for_continue(self.message)
        _, last_part = split_last(stripped_message, split_marker)
        message = other.message
        if last_part in message:
            # truncate the overlapping part, patch the messages together
            self.message = (
                stripped_message + message[message.index(last_part) + len(last_part) :]
            )
        else:
            self.message += message  # extend the message
            logger.warning(
                f"Last part {last_part} not found in the message. "
                "Appending the message directly."
            )

        def as_list(obj: Any) -> List[Any]:
            # Normalize raw_response so repeated updates accumulate a flat list.
            if isinstance(obj, list):
                return obj
            return [obj]

        # NOTE(review): `tool_calls` defaults to [] (never None), so it is never
        # copied by this loop — confirm whether that is intended.
        for k in ["finish_reason", "response_model", "response_obj", "tool_calls"]:
            if getattr(self, k) is None:
                setattr(self, k, getattr(other, k))
        self.raw_response = as_list(self.raw_response) + as_list(other.raw_response)
        self.chunks += other.chunks
        self.num_raw_completions += other.num_raw_completions
        if other.cost is not None:
            self.cost = (self.cost or 0) + other.cost
        if other.usage is not None:

            def merge_usage(usage1: BaseModel, usage2: BaseModel) -> None:
                """Merge the usage from two responses recursively."""
                for k, v in usage2.model_dump().items():
                    if isinstance(v, int) or isinstance(v, float):
                        if hasattr(usage1, k):
                            setattr(usage1, k, getattr(usage1, k) + v)
                    elif isinstance(v, BaseModel):
                        merge_usage(getattr(usage1, k), v)

            merge_usage(self.usage, other.usage)  # type: ignore

        return self

    def streaming(
        self,
        display: Optional[str] = None,
        title: str = "APPL Streaming",
        display_prefix_content: str = "",
        live: Optional[Live] = None,
    ) -> "CompletionResponse":
        """Stream the response object and finish the response.

        Args:
            display: Display mode, one of "live", "print" or "none";
                defaults to the configured display mode.
            title: Title of the live display panel.
            display_prefix_content: Content shown before the streamed text.
            live: An existing rich Live display to reuse; a new one is
                created (and stopped afterwards) if not provided.

        Raises:
            ValueError: If the response is not a stream or the display mode
                is unknown.
        """
        if not self.is_stream:
            raise ValueError("Cannot iterate over non-streaming response")
        if self.is_finished:
            return self

        if self.response_obj is not None:
            target = self.response_obj
        else:
            target = self.format_stream()

        display = display or configs.getattrs(
            "settings.logging.display.display_mode", "live"
        )
        # Fix: track the last chunk explicitly so an empty stream no longer
        # dereferences an unbound loop variable (previously a NameError).
        chunk: Any = None
        if display == "live":
            start_time = time.time()

            def panel(
                content: str, iter_index: Optional[int] = None, truncate: bool = False
            ) -> Panel:
                # Build the display panel, appending iteration speed info
                # when an iteration index is available.
                style = "magenta"
                display_title = title
                if iter_index is not None:
                    time_elapsed = time.time() - start_time
                    avg_iters_per_sec = (iter_index + 1) / time_elapsed
                    stream_info = (
                        f"[{time_elapsed:.3f}s ({avg_iters_per_sec:.2f} it/s)]"
                    )
                    display_title += f" - {stream_info}"
                return make_panel(
                    content, title=display_title, style=style, truncate=truncate
                )

            if live is None:
                live = get_live()
                need_stop = True  # we own the live display, stop it at the end
            else:
                need_stop = False
            content = display_prefix_content
            i = -1  # stays -1 when the stream yields nothing
            for i, chunk in enumerate(iter(target)):
                if isinstance(chunk, BaseModel):
                    content = json.dumps(chunk.model_dump(), indent=2)
                else:
                    content += str(chunk)
                live.update(panel(content, i, truncate=True))
                # live.refresh() # might be too frequent
            # display untruncated content at the end
            live.update(panel(content, i if i >= 0 else None))
            live.refresh()
            if need_stop:
                stop_live()
        elif display == "print":
            last_content = ""

            def eprint(content: str, color: Optional[Color] = None) -> None:
                # Print without newline and flush so streaming is visible.
                print(colored(content, color) if color else content, end="")
                sys.stdout.flush()

            eprint("\n===== START OF APPL STREAMING =====\n", color="magenta")
            self.register_post_finish_callback(
                lambda _: eprint(
                    "\n===== END OF APPL STREAMING =====\n", color="magenta"
                ),
                order="first",
            )
            eprint(display_prefix_content, color="grey")
            for chunk in iter(target):
                if isinstance(chunk, BaseModel):
                    # Structured chunks restate the whole object; only print
                    # the part not already shown.
                    content = json.dumps(chunk.model_dump(), indent=2)
                    if last_content in content:
                        eprint(content[content.index(last_content) :])
                    else:
                        eprint(content)
                    last_content = content
                else:
                    eprint(str(chunk))

        elif display == "none":
            # Consume the stream silently to finish the response.
            for chunk in iter(target):
                pass
        else:
            raise ValueError(
                f"Unknown display argument: {display}, only 'live', 'print' and 'none' are supported"
            )
        # The last chunk of a response-object stream is the final object.
        if self.response_obj is not None and chunk is not None:
            self.set_response_obj(chunk)
        return self

    def register_post_finish_callback(
        self,
        callback: Callable,
        order: Literal["first", "last"] = "last",
    ) -> None:
        """Register a post finish callback.

        The callback will be called after the response is finished.
        """
        if self.is_finished:
            # Already finished: run the callback immediately.
            callback(self)
        else:
            if order not in ["first", "last"]:
                raise ValueError(
                    f"Unknown order argument: {order}, only 'first' and 'last' are supported"
                )
            if order == "last":
                self.post_finish_callbacks.append(callback)
            else:
                self.post_finish_callbacks.insert(0, callback)

    def format_stream(self):
        """Format the stream response as a text generator."""
        suffix = ""
        for chunk in iter(self):
            # chunk: Union[ModelResponse, ChatCompletionChunk]
            delta: Union[Delta, ChoiceDelta] = chunk.choices[0].delta  # type: ignore

            if delta is not None:
                if delta.content is not None:
                    yield delta.content
                elif getattr(delta, "tool_calls", None):
                    f: Union[Function, ChoiceDeltaToolCallFunction] = delta.tool_calls[
                        0
                    ].function  # type: ignore
                    if f.name is not None:
                        # A new tool call starts: close the previous one (if
                        # any) and open "name(".
                        if suffix:
                            yield f"{suffix}, "
                        yield f"{f.name}("
                        suffix = ")"
                    if f.arguments is not None:
                        yield f.arguments
        yield suffix

    def _finish(self, response: Any) -> None:
        """Finalize the response: record usage/cost, parse message and tool
        calls, and run the registered post-finish callbacks."""
        if self.is_finished:
            logger.warning("Response already finished. Ignoring finish call.")
            return
        self.is_finished = True
        self._complete_response = response
        self.usage = getattr(response, "usage", None)
        self.cost = 0.0
        try:
            # Best-effort: cost computation may fail for unknown models.
            self.cost = completion_cost(response)
        except Exception:
            pass
        # parse the message and tool calls
        if isinstance(response, (ModelResponse, ChatCompletion)):
            message = response.choices[0].message  # type: ignore
            self.finish_reason = response.choices[0].finish_reason
            if tool_calls := getattr(message, "tool_calls", None):
                for call in tool_calls:
                    self.tool_calls.append(ToolCall.from_openai_tool_call(call))
            elif message.content is not None:
                self.message = message.content
            else:
                raise ValueError(f"Invalid response: {response}")
        elif response is None:
            logger.warning("Response is None, only used for testing")
        else:
            raise ValueError(f"Unknown response type: {type(response)}")

        # post finish hook
        for callback in self.post_finish_callbacks:
            try:
                callback(self)
            except Exception as e:
                logger.error(
                    f"Error when calling post finish callback {callback.__name__}: {e}"
                )
                raise e

    def _finish_stream(self) -> None:
        """Assemble the accumulated chunks into a complete response."""
        try:
            response = stream_chunk_builder(self.chunks)
        except Exception:
            logger.error("Error when building the response from the stream")
            raise
        self._finish(response)

    def __str__(self):
        if self.is_stream and not self.is_finished:
            # Avoid consuming the stream just to stringify.
            return repr(self)
        return str(self.results)

    def __next__(self):
        if not self.is_stream:
            raise ValueError("Cannot iterate over non-streaming response")
        try:
            chunk = next(self.raw_response)
            self.chunks.append(chunk)
            return chunk
        except StopIteration:
            # Stream exhausted: build and finalize the complete response.
            self._finish_stream()
            raise StopIteration

    def __iter__(self):
        if not self.is_stream:
            raise ValueError("Cannot iterate over non-streaming response")
        return self

    def __getattr__(self, name: str) -> Any:
        # Delegate unknown attributes to the complete response; warn (and
        # return None) instead of blocking when the response is unfinished.
        if not self.is_finished:
            logger.warning(
                f"Cannot get {name} attribute before the response is finished. "
                "Returning None."
            )
            return None
        return getattr(self.complete_response, name)