Coverage for src/appl/core/generation.py: 84%

237 statements  


from __future__ import annotations

import json
import threading
from typing import (
    Any,
    Callable,
    Generic,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

from loguru import logger
from pydantic import BaseModel
from rich.live import Live

from .config import configs
from .context import PromptContext
from .globals import (
    get_thread_local,
    inc_global_var,
    inc_thread_local,
    set_thread_local,
)
from .message import AIMessage, BaseMessage, ToolMessage, UserMessage
from .promptable import Promptable
from .response import CompletionResponse
from .server import BaseServer, GenArgs
from .tool import BaseTool, ToolCall
from .trace import GenerationInitEvent, GenerationResponseEvent, add_to_trace
from .types import (
    CallFuture,
    ExecutorType,
    MessageRole,
    MessageRoleType,
    ResponseType,
    String,
    StringFuture,
)
from .utils import get_live, split_last, stop_live, strip_for_continue

M = TypeVar("M")
APPL_GEN_NAME_PREFIX_KEY = "_appl_gen_name_prefix"
LAST_LINE_MARKER = "<last_line>"
LAST_PART_MARKER = "<last_part>"
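# These markers are used by Generation._continue_llm below to wrap the tail of a
# cut-off completion (its last line, or the last part after a space/comma) so the
# continuation prompt can refer to that tail unambiguously.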


def set_gen_name_prefix(prefix: str) -> None:
    """Set the prefix for generation names in the current thread."""
    set_thread_local(APPL_GEN_NAME_PREFIX_KEY, prefix)


def get_gen_name_prefix() -> Optional[str]:
    """Get the prefix for generation names in the current thread."""
    gen_name_prefix = get_thread_local(APPL_GEN_NAME_PREFIX_KEY, None)
    if gen_name_prefix is None:
        thread_name = threading.current_thread().name
        if thread_name != "MainThread":
            gen_name_prefix = thread_name
    return gen_name_prefix
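
# Illustrative example: after `set_gen_name_prefix("planner")` in a worker thread,
# generations created in that thread are named "@planner_gen_0", "@planner_gen_1", ...
# In a non-main thread without an explicit prefix, the thread name is used instead,
# and in the main thread the names fall back to "@gen_0", "@gen_1", ...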


class Generation(Generic[M]):
    """Represents a generation call to the model."""

    def __init__(
        self,
        server: BaseServer,
        args: GenArgs,
        *,
        max_relay_rounds: int = 0,
        mock_response: Optional[Union[CompletionResponse, str]] = None,
        llm_executor_type: ExecutorType = ExecutorType.LLM_THREAD_POOL,
        lazy_eval: bool = False,
        _ctx: Optional[PromptContext] = None,
        **kwargs: Any,
        # kwargs used for extra args for the create method
    ) -> None:
        """Initialize the Generation object.

        Args:
            server: An LLM server where the generation request will be sent.
            args: The arguments of the generation call.
            max_relay_rounds: The maximum number of relay rounds to continue the unfinished text generation.
            mock_response: A mock response for the generation call.
            llm_executor_type: The type of the executor to run the LLM call.
            lazy_eval: If True, the generation call will be evaluated lazily.
            _ctx: The prompt context filled automatically by the APPL function.
            **kwargs: Extra arguments for the generation call.
        """
        # name needs to be unique and ordered, so it has to be generated in the main thread
        gen_name_prefix = get_gen_name_prefix()
        # take the value before increment
        self._cnt = inc_thread_local(f"{gen_name_prefix}_gen_cnt") - 1
        if gen_name_prefix is None:
            self._id = f"@gen_{self._cnt}"
        else:
            self._id = f"@{gen_name_prefix}_gen_{self._cnt}"

        self._server = server
        self._model_name = server.model_name
        self._args = args
        self._max_relay_rounds = max_relay_rounds
        self._mock_response = mock_response
        self._llm_executor_type = llm_executor_type
        self._lazy_eval = lazy_eval
        self._ctx = _ctx
        self._extra_args = kwargs
        self._num_raw_completions = 0
        self._cached_response: Optional[CompletionResponse] = None

        add_to_trace(GenerationInitEvent(name=self.id))
        log_llm_call_args = configs.getattrs("settings.logging.display.llm_call_args")
        if log_llm_call_args:
            logger.info(
                f"Call generation [{self.id}] with args: {args} and kwargs: {kwargs}"
            )

        if isinstance(mock_response, CompletionResponse):

            def get_response() -> CompletionResponse:
                return mock_response

            self._call = self._wrap_response(get_response)
        else:
            if mock_response:
                # use litellm's mock response
                kwargs.update({"mock_response": mock_response})
            self._call = self._wrap_response(self._call_llm())

        # tools
        self._tools: Sequence[BaseTool] = args.tools
        self._name2tools = {tool.name: tool for tool in self._tools}
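        # `self._call` is a zero-argument callable produced by `_wrap_response`;
        # the `response` property below invokes it once and caches the resulting
        # CompletionResponse in `self._cached_response`.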

    def _call_llm(self) -> CallFuture[CompletionResponse]:
        """Call the LLM server asynchronously to get the completion response."""
        self._num_raw_completions += 1
        return CallFuture(
            self._server.create,
            executor_type=self._llm_executor_type,
            lazy_eval=self._lazy_eval,
            args=self._args,
            gen_id=f"{self.id}_raw_{self._num_raw_completions - 1}",
            **self._extra_args,
        )

    def _continue_llm(
        self, results: CompletionResponse, live: Optional[Live] = None
    ) -> CompletionResponse:
        assert results.message is not None, "Cannot continue an empty message"

        cutoff_content = strip_for_continue(results.message)
        continue_prompt = configs.getattrs("prompts.continue_generation")
        continue_prompt_alt = configs.getattrs("prompts.continue_generation_alt")

        # Choose a proper split marker for the continuation
        for split_marker in ["\n", " ", ","]:
            content, last_part = split_last(cutoff_content, split_marker)
            if content is not None:  # found split_marker in the content
                prompt = (
                    continue_prompt if split_marker == "\n" else continue_prompt_alt
                )
                marker = LAST_LINE_MARKER if split_marker == "\n" else LAST_PART_MARKER
                break
        marked_cutoff_content = f"{content}{split_marker}{marker}{last_part}{marker}"
        prompt = prompt.format(last_marker=marker)

        messages = self._args.messages
        messages.append(AIMessage(content=marked_cutoff_content))
        messages.append(UserMessage(content=prompt))
        # print(messages, "\n")  # DEBUG

        # call the LLM again and wait for the result
        response = self._call_llm().result()
        if response.type == ResponseType.UNFINISHED:
            response.streaming(
                title=f"Continue generation [{self.id}]",
                display_prefix_content=marked_cutoff_content + "\n",
                live=live,
            )

        # pop the last two messages
        for _ in range(2):
            messages.pop()

        results.update(response, split_marker)
        return response
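
    # Illustrative sketch of the marking scheme: if the cut-off text ends with
    # "line1\nline2 part", splitting on "\n" gives content="line1" and
    # last_part="line2 part", so the model is shown the assistant message
    # "line1\n<last_line>line2 part<last_line>" together with a user prompt built
    # from configs "prompts.continue_generation" (formatted with the marker).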

    def _wrap_response(
        self, get_response: Callable[[], CompletionResponse]
    ) -> Callable[[], CompletionResponse]:
        """Wrap the LLM calls to address incomplete completion."""

        def inner() -> CompletionResponse:
            log_llm_usage = configs.getattrs("settings.logging.display.llm_usage")
            log_llm_response = configs.getattrs("settings.logging.display.llm_response")
            log_llm_cost = configs.getattrs("settings.logging.display.llm_cost")

            results = response = get_response()

            if self._max_relay_rounds > 0:
                live = None
                display_mode = configs.getattrs(
                    "settings.logging.display.display_mode", "live"
                )
                need_live = self._args.stream and display_mode == "live"
                if response.type == ResponseType.UNFINISHED:
                    if need_live:
                        live = get_live()
                    response.streaming(title=f"Generation [{self.id}]", live=live)

                for i in range(self._max_relay_rounds):
                    if response.finish_reason in ["length"]:
                        generated_chars = len(results.message or "")
                        logger.info(
                            f"[Round {i + 1}/{self._max_relay_rounds}, "
                            f"generated {generated_chars} chars] "
                            f"Generation [{self.id}] was cut off due to max_tokens, "
                            "automatically continue the generation."
                        )
                        if need_live and live is None:
                            live = get_live()
                        response = self._continue_llm(results, live=live)
                    else:
                        break

                if live is not None:
                    stop_live()

            def handle_results(results: CompletionResponse) -> None:
                if log_llm_response:
                    logger.info(f"Generation [{self.id}] results: {results}")
                if results.usage and log_llm_usage:
                    logger.info(f"Generation [{self.id}] token usage: {results.usage}")

                num_requests = inc_global_var(f"{self._model_name}_num_requests")
                if log_llm_cost:
                    currency = getattr(self._server, "_cost_currency", "USD")
                    if self._mock_response is not None:
                        logger.info(
                            "Mock response, estimated cost for real request: "
                            f"{results.cost:.4f} {currency}"
                        )
                    elif results.cost is None:
                        logger.warning(
                            f"No cost information for generation [{self.id}]"
                        )
                    else:
                        total_cost = inc_global_var(
                            f"{self._model_name}_api_cost", results.cost
                        )
                        logger.info(
                            f"API cost for this request: {results.cost:.4f}, "
                            f"in total: {total_cost:.4f} {currency}. "
                            f"Total number of requests: {num_requests}."
                        )
                create_args = self._server._get_create_args(
                    self._args, **self._extra_args
                )
                dump_args = create_args.copy()
                for k, v in dump_args.items():
                    if k in ["response_format", "response_model"]:
                        if isinstance(v, type) and issubclass(v, BaseModel):
                            dump_args[k] = json.dumps(v.model_json_schema(), indent=4)

                add_to_trace(
                    GenerationResponseEvent(
                        name=self.id, args=dump_args, ret=str(results)
                    )
                )

            results.register_post_finish_callback(handle_results)

            return results

        return inner
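
    # Summary of `inner`: it obtains the initial completion from `get_response`,
    # and when `max_relay_rounds > 0` it consumes streaming (UNFINISHED) responses
    # and keeps calling `_continue_llm` while the finish reason is "length",
    # merging each continuation into `results` via `results.update(...)`.
    # `handle_results` runs as a post-finish callback to log the response, usage,
    # and cost, and to record the GenerationResponseEvent in the trace.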

    @property
    def id(self) -> str:
        """The unique ID of the generation."""
        return self._id

    @property
    def response(self) -> CompletionResponse:
        """The response of the generation call."""
        # NOTE: the result of the call should be cached
        if self._cached_response is None:
            self._cached_response = self._call()
        return self._cached_response

    @property
    def response_type(self) -> ResponseType:
        """The type of the response."""
        return self.response.type

    @property
    def is_message(self) -> bool:
        """Whether the response is a text message."""
        return self.response_type == ResponseType.TEXT

    @property
    def is_tool_call(self) -> bool:
        """Whether the response is a tool call."""
        return self.response_type == ResponseType.TOOL_CALL

    @property
    def is_obj(self) -> bool:
        """Whether the response is an object."""
        return self.response_type == ResponseType.OBJECT

    @property
    def message(self) -> Optional[str]:
        """The message of the response."""
        return self.response.message

    @property
    def tool_calls(self) -> List[ToolCall]:
        """The tool calls of the response."""
        return self.response.tool_calls

    @property
    def response_obj(self) -> M:
        """The object of the response."""
        return self.response.response_obj

    @property
    def results(self) -> Union[M, str, List[ToolCall]]:
        """The results of the response."""
        return self.response.results

    @property
    def str_future(self) -> StringFuture:
        """The StringFuture representation of the response."""
        return StringFuture(self)

    @property
    def text_stream(self):
        """Get the response of the generation as a text stream."""
        return self.response.format_stream()
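
    # Typical read patterns (illustrative): `message` for plain text responses,
    # `tool_calls` when `is_tool_call` is True, `response_obj` for structured
    # outputs, and `str_future` when the result is to be consumed lazily as a
    # string.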

    def _call_tool(
        self,
        name: str,
        args: str,
        parallel: bool = False,
        executor_type: ExecutorType = ExecutorType.GENERAL_THREAD_POOL,
    ) -> Any:
        try:
            kwargs = json.loads(args)
        except json.JSONDecodeError as e:
            raise ValueError(f"Error parsing args: {args}") from e
        args_str = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
        if configs.getattrs("settings.logging.display.tool_calls"):
            logger.info(f"Running tool call: {name}({args_str})")

        if name not in self._name2tools:
            raise ValueError(f"Error: Tool {name} not found")
        tool = self._name2tools[name]
        try:
            if parallel:
                res = CallFuture(tool, executor_type=executor_type, **kwargs)
            else:
                res = tool(**kwargs)
        except Exception as e:
            raise RuntimeError(f"Error running tool call: {name}({args_str})") from e

        return res

    def run_tool_calls(
        self,
        filter_fn: Optional[Callable[[List[ToolCall]], List[ToolCall]]] = None,
        parallel: bool = False,
        executor_type: ExecutorType = ExecutorType.GENERAL_THREAD_POOL,
        log_results: Optional[bool] = None,
    ) -> List[ToolMessage]:
        """Run all tool calls in the generation and return the results.

        Args:
            filter_fn:
                A function that takes a list of ToolCall objects and returns
                a filtered list of ToolCall objects. This function can be
                used to filter the tool calls that will be run.
            parallel: If True, run the tool calls in parallel. Defaults to False.
            executor_type:
                The type of the executor to run the tool calls, can be
                "general_thread_pool", "general_process_pool", "new_thread" or
                "new_process".
            log_results:
                If True, log the results of the tool calls. Note that this will
                wait for the results to be ready. Defaults to the setting in configs.

        Returns:
            A list of ToolMessage objects.
        """
        if not self.is_tool_call:
            raise ValueError("Error: The Generation is not a tool call")
        if log_results is None:
            log_results = configs.getattrs("settings.logging.display.tool_results")
        tool_calls = self.tool_calls
        if filter_fn:
            tool_calls = filter_fn(tool_calls)
        messages = []
        for tc in tool_calls:
            role = MessageRole(MessageRoleType.TOOL, tc.name)
            try:
                tool_result = self._call_tool(
                    tc.name, tc.args, parallel=parallel, executor_type=executor_type
                )
                msg = ToolMessage(
                    tool_result, role=role, tool_call_id=tc.id, has_error=False
                )
            except Exception as e:
                logger.error(f"Error running tool call: {tc.name}({tc.args})")
                logger.error(e)
                msg = ToolMessage(str(e), role=role, tool_call_id=tc.id, has_error=True)
            messages.append(msg)
        if log_results:  # this will wait for the results to be ready
            for msg in messages:
                logger.info(f"Tool call result: {msg}")
        return messages
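
    # Rough usage sketch (illustrative; `gen` stands for a Generation obtained from
    # the surrounding APPL program):
    #
    #     if gen.is_tool_call:
    #         tool_messages = gen.run_tool_calls(parallel=True)
    #     else:
    #         text = gen.message
    #
    # With `parallel=True`, each tool runs through a CallFuture on the chosen
    # executor, so the ToolMessage contents stay unresolved until they are needed
    # (for example, when `log_results` forces them to be logged).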

    def as_prompt(self) -> Union[AIMessage, StringFuture]:
        """Get the response of the generation as a promptable object."""
        if self._args.tools:
            if self.is_tool_call:
                return AIMessage(tool_calls=self.tool_calls)
        # return a future object of value: str(self._call()), without blocking
        return StringFuture(CallFuture(self._call))
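
    # When appended back into a prompt, a tool-call response becomes an AIMessage
    # carrying the tool calls; otherwise the generation is represented as a
    # StringFuture, so building the prompt does not block on the LLM result.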

    def __add__(self, other: Union[String, "Generation"]) -> StringFuture:
        # Assume generation is a string
        if isinstance(other, Generation):
            return self.str_future + other.str_future
        elif isinstance(other, (str, StringFuture)):
            return self.str_future + other
        raise TypeError(
            f"unsupported operand type(s) for +: 'Generation' and '{type(other)}'"
        )

    def __radd__(self, other: String) -> StringFuture:
        # Assume generation is a string
        if isinstance(other, (str, StringFuture)):
            return other + self.str_future
        raise TypeError(
            f"unsupported operand type(s) for +: '{type(other)}' and 'Generation'"
        )

    def __getattr__(self, name: str) -> Any:
        assert name != "response", "Internal Error within self.response"
        return getattr(self.response, name)

    def __str__(self) -> str:
        return str(self.response.results)

    def __repr__(self) -> str:
        return f"Generation(id={self.id})"

    def __call__(self):
        """Get the response of the generation call."""
        return self.response
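
    # Dunder summary (illustrative): `gen + "suffix"` and `"prefix" + gen` build
    # StringFutures without waiting for the completion; `str(gen)` and attribute
    # access that falls through `__getattr__` resolve the response; `gen()` is
    # shorthand for `gen.response`.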