Coverage for src/appl/core/response.py: 71%

266 statements  

coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

import json
import os
import shutil
import sys
import time
from typing import Any, Callable, List, Literal, Optional, Union

from litellm import CustomStreamWrapper, completion_cost, stream_chunk_builder
from litellm.exceptions import NotFoundError
from litellm.types.utils import Delta, Function, ModelResponse
from loguru import logger
from openai import Stream
from openai.types import CompletionUsage
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionMessageToolCall,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import (
    ChoiceDelta,
    ChoiceDeltaToolCallFunction,
)
from pydantic import BaseModel, Field, model_validator
from rich.live import Live
from rich.panel import Panel
from rich.syntax import Syntax
from termcolor import colored
from termcolor._types import Color

from .config import configs
from .tool import ToolCall
from .types import ResponseType
from .utils import get_live, make_panel, split_last, stop_live, strip_for_continue

class CompletionResponse(BaseModel):
    """A class wrapping the response from the LLM model.

    For a streaming response, it tracks the chunks of the response and
    builds the complete response when the streaming is finished.
    """

    raw_response: Any = Field(None, description="The raw response from the model")
    """The raw response from the model."""
    cost: Optional[float] = Field(None, description="The cost of the completion")
    """The cost of the completion."""
    usage: Optional[CompletionUsage] = Field(
        None, description="The usage of the completion"
    )
    """The usage of the completion."""
    finish_reason: Optional[str] = Field(
        None, description="The reason why the completion is finished for the top-choice"
    )
    """The reason why the completion is finished for the top-choice."""
    num_raw_completions: int = Field(1, description="The number of raw completions")
    """The number of raw completions."""
    chunks: List[Union[ModelResponse, ChatCompletionChunk]] = Field(
        [], description="The chunks of the response when streaming"
    )
    """The chunks of the response when streaming."""
    is_stream: bool = Field(False, description="Whether the response is a stream")
    """Whether the response is a stream."""
    is_finished: bool = Field(
        False, description="Whether the response stream is finished"
    )
    """Whether the response stream is finished."""
    post_finish_callbacks: List[Callable] = Field(
        [], description="The post finish callbacks"
    )
    """The post finish callbacks."""
    response_model: Any = Field(
        None, description="The BaseModel subclass specifying the response format."
    )
    """The BaseModel subclass specifying the response format."""
    response_obj: Any = Field(
        None, description="The response object of the response model; could be a stream"
    )
    """The response object of the response model; could be a stream."""
    message: Optional[str] = Field(
        None, description="The top-choice message from the completion"
    )
    """The top-choice message from the completion."""
    tool_calls: List[ToolCall] = Field([], description="The tool calls")
    """The tool calls."""

    @model_validator(mode="after")
    def _post_init(self) -> "CompletionResponse":
        self._complete_response = None

        if isinstance(self.raw_response, (CustomStreamWrapper, Stream)):
            # ? support for async streams?
            self.is_stream = True
        else:
            self._finish(self.raw_response)  # type: ignore
        return self

    def set_response_obj(self, response_obj: Any) -> None:
        """Set the response object."""
        self.response_obj = response_obj

    @property
    def complete_response(self) -> Union[ModelResponse, ChatCompletion]:
        """The complete response from the model. This will block until the response is finished."""
        if self.is_finished:
            return self._complete_response  # type: ignore
        self.streaming()  # ? when should we set display to False?
        assert self.is_finished, "Response should be finished after streaming"
        return self._complete_response  # type: ignore

    @property
    def results(self) -> Any:
        """The results of the response.

        Returns:
            message (str):
                The message if the response is a text completion.
            tool_calls (List[ToolCall]):
                The tool calls if the response is a list of tool calls.
            response_obj (Any):
                The object if the response is a response object.
        """
        if self.is_stream and not self.is_finished:
            self.streaming()  # display the stream and finish the response
        results: Any = self.message
        if self.response_obj is not None:
            results = self.response_obj
        elif len(self.tool_calls):
            results = self.tool_calls
        return results

    @property
    def type(self) -> ResponseType:
        """The type of the response."""
        if not self.is_finished:
            return ResponseType.UNFINISHED
        if self.response_model is not None and self.response_obj is not None:
            return ResponseType.OBJECT
        if len(self.tool_calls):
            return ResponseType.TOOL_CALL
        return ResponseType.TEXT

    def update(
        self, other: "CompletionResponse", split_marker: str = "\n"
    ) -> "CompletionResponse":
        """Update the response with the information contained in the other response."""
        if not self.is_finished:
            raise ValueError("Cannot update unfinished response")
        if self.type != other.type:
            raise ValueError(
                f"Cannot update response with type {self.type} "
                f"with another response of type {other.type}"
            )
        if self.type != ResponseType.TEXT:
            raise NotImplementedError("Not supported for non-text response")
        if self.message is None or other.message is None:
            raise ValueError("Not supported for empty message when updating")

        stripped_message = strip_for_continue(self.message)
        _, last_part = split_last(stripped_message, split_marker)
        message = other.message
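        # Illustrative example (hypothetical values, assuming strip_for_continue
        # leaves the message unchanged here): if self.message ends with
        # "step 1\nstep 2" and other.message is "step 2 continued", last_part is
        # "step 2", so the merged message becomes "step 1\nstep 2 continued".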

        if last_part in message:
            # truncate the overlapping part, patch the messages together
            self.message = (
                stripped_message + message[message.index(last_part) + len(last_part) :]
            )
        else:
            self.message += message  # extend the message
            logger.warning(
                f"Last part {last_part} not found in the message. "
                "Appending the message directly."
            )

        def as_list(obj: Any) -> List[Any]:
            if isinstance(obj, list):
                return obj
            return [obj]

        for k in ["finish_reason", "response_model", "response_obj", "tool_calls"]:
            if getattr(self, k) is None:
                setattr(self, k, getattr(other, k))
        self.raw_response = as_list(self.raw_response) + as_list(other.raw_response)
        self.chunks += other.chunks
        self.num_raw_completions += other.num_raw_completions
        if other.cost is not None:
            self.cost = (self.cost or 0) + other.cost
        if other.usage is not None:

            def merge_usage(usage1: BaseModel, usage2: BaseModel) -> None:
                """Merge the usage from two responses recursively."""

                for k, v in usage2.model_dump().items():
                    if isinstance(v, int) or isinstance(v, float):
                        if hasattr(usage1, k):
                            setattr(usage1, k, getattr(usage1, k) + v)
                    elif isinstance(v, BaseModel):
                        merge_usage(getattr(usage1, k), v)

            merge_usage(self.usage, other.usage)  # type: ignore

        return self

    def streaming(
        self,
        display: Optional[str] = None,
        title: str = "APPL Streaming",
        display_prefix_content: str = "",
        live: Optional[Live] = None,
    ) -> "CompletionResponse":
        """Stream the response object and finish the response."""
        if not self.is_stream:
            raise ValueError("Cannot iterate over non-streaming response")
        if self.is_finished:
            return self

        if self.response_obj is not None:
            target = self.response_obj
        else:
            target = self.format_stream()

        display = display or configs.getattrs(
            "settings.logging.display.display_mode", "live"
        )
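        # display resolves to "live" (rich Live panel), "print" (plain stdout
        # printing), or "none" (consume the stream silently), handled below.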

        if display == "live":
            start_time = time.time()

            def panel(
                content: str, iter_index: Optional[int] = None, truncate: bool = False
            ) -> Panel:
                style = "magenta"
                display_title = title
                if iter_index is not None:
                    time_elapsed = time.time() - start_time
                    avg_iters_per_sec = (iter_index + 1) / time_elapsed
                    stream_info = (
                        f"[{time_elapsed:.3f}s ({avg_iters_per_sec:.2f} it/s)]"
                    )
                    display_title += f" - {stream_info}"
                return make_panel(
                    content, title=display_title, style=style, truncate=truncate
                )

            if live is None:
                live = get_live()
                need_stop = True
            else:
                need_stop = False
            content = display_prefix_content
            for i, chunk in enumerate(iter(target)):
                if isinstance(chunk, BaseModel):
                    content = json.dumps(chunk.model_dump(), indent=2)
                else:
                    content += str(chunk)
                live.update(panel(content, i, truncate=True))
                # live.refresh()  # might be too frequent
            # display untruncated content at the end
            live.update(panel(content, i))
            live.refresh()
            if need_stop:
                stop_live()
        elif display == "print":
            last_content = ""

            def eprint(content: str, color: Optional[Color] = None) -> None:
                print(colored(content, color) if color else content, end="")
                sys.stdout.flush()

            eprint("\n===== START OF APPL STREAMING =====\n", color="magenta")
            self.register_post_finish_callback(
                lambda _: eprint(
                    "\n===== END OF APPL STREAMING =====\n", color="magenta"
                ),
                order="first",
            )
            eprint(display_prefix_content, color="grey")
            for chunk in iter(target):
                if isinstance(chunk, BaseModel):
                    content = json.dumps(chunk.model_dump(), indent=2)
                    if last_content in content:
                        eprint(content[content.index(last_content) :])
                    else:
                        eprint(content)
                    last_content = content
                else:
                    eprint(str(chunk))

        elif display == "none":
            for chunk in iter(target):
                pass
        else:
            raise ValueError(
                f"Unknown display argument: {display}, only 'live', 'print' and 'none' are supported"
            )
        if self.response_obj is not None:
            self.set_response_obj(chunk)
        return self

    def register_post_finish_callback(
        self,
        callback: Callable,
        order: Literal["first", "last"] = "last",
    ) -> None:
        """Register a post finish callback.

        The callback will be called after the response is finished.
        """
        if self.is_finished:
            callback(self)
        else:
            if order not in ["first", "last"]:
                raise ValueError(
                    f"Unknown order argument: {order}, only 'first' and 'last' are supported"
                )
            if order == "last":
                self.post_finish_callbacks.append(callback)
            else:
                self.post_finish_callbacks.insert(0, callback)

    def format_stream(self):
        """Format the stream response as a text generator."""

        suffix = ""
        for chunk in iter(self):
            # chunk: Union[ModelResponse, ChatCompletionChunk]
            delta: Union[Delta, ChoiceDelta] = chunk.choices[0].delta  # type: ignore

            if delta is not None:
                if delta.content is not None:
                    yield delta.content
                elif getattr(delta, "tool_calls", None):
                    f: Union[Function, ChoiceDeltaToolCallFunction] = delta.tool_calls[
                        0
                    ].function  # type: ignore
                    if f.name is not None:
                        if suffix:
                            yield f"{suffix}, "
                        yield f"{f.name}("
                        suffix = ")"
                    if f.arguments is not None:
                        yield f.arguments
        yield suffix

    def _finish(self, response: Any) -> None:
        if self.is_finished:
            logger.warning("Response already finished. Ignoring finish call.")
            return
        self.is_finished = True
        self._complete_response = response
        self.usage = getattr(response, "usage", None)
        self.cost = 0.0
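        # completion_cost may raise, e.g. for models without pricing info
        # registered in litellm; the 0.0 above then remains as a fallback.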

        try:
            self.cost = completion_cost(response)
        except Exception:
            pass

        # parse the message and tool calls
        if isinstance(response, (ModelResponse, ChatCompletion)):
            message = response.choices[0].message  # type: ignore
            self.finish_reason = response.choices[0].finish_reason
            if tool_calls := getattr(message, "tool_calls", None):
                for call in tool_calls:
                    self.tool_calls.append(ToolCall.from_openai_tool_call(call))
            elif message.content is not None:
                self.message = message.content
            else:
                raise ValueError(f"Invalid response: {response}")
        elif response is None:
            logger.warning("Response is None, only used for testing")
        else:
            raise ValueError(f"Unknown response type: {type(response)}")

        # post finish hook
        for callback in self.post_finish_callbacks:
            try:
                callback(self)
            except Exception as e:
                logger.error(
                    f"Error when calling post finish callback {callback.__name__}: {e}"
                )
                raise e

    def _finish_stream(self) -> None:
        try:
            response = stream_chunk_builder(self.chunks)
        except Exception:
            logger.error("Error when building the response from the stream")
            raise
        self._finish(response)

    def __str__(self):
        if self.is_stream and not self.is_finished:
            return repr(self)
        return str(self.results)

    def __next__(self):
        if not self.is_stream:
            raise ValueError("Cannot iterate over non-streaming response")
        try:
            chunk = next(self.raw_response)
            self.chunks.append(chunk)
            return chunk
        except StopIteration:
            self._finish_stream()
            raise StopIteration

    def __iter__(self):
        if not self.is_stream:
            raise ValueError("Cannot iterate over non-streaming response")
        return self

    def __getattr__(self, name: str) -> Any:
        if not self.is_finished:
            logger.warning(
                f"Cannot get {name} attribute before the response is finished. "
                "Returning None."
            )
            return None
        return getattr(self.complete_response, name)
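# Usage sketch (illustrative only; the model name and prompt are assumptions,
# not part of this module):
#     import litellm
#     raw = litellm.completion(
#         model="gpt-4o-mini",
#         messages=[{"role": "user", "content": "Hello"}],
#         stream=True,
#     )
#     response = CompletionResponse(raw_response=raw)
#     response.streaming(display="print")  # consume and display the stream
#     print(response.type, response.results)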