Coverage for src/appl/func.py: 73%

174 statements  


import copy
import inspect
from typing import (
    Any,
    Callable,
    Dict,
    Literal,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
    get_origin,
    overload,
)

from loguru import logger
from pydantic import BaseModel

from .core import (
    BaseTool,
    Compositor,
    Conversation,
    GenArgs,
    Generation,
    PrinterPop,
    PrinterPush,
    PromptContext,
    PromptFunc,
    PromptRecords,
    Tool,
    need_ctx,
    partial,
    wraps,
)
from .core.printer import Indexing
from .core.response import CompletionResponse
from .core.runtime import appl_execute
from .core.trace import traceable
from .servers import server_manager
from .types import (
    CallFuture,
    ExecutorType,
    MaybeOneOrMany,
    OneOrMany,
    ParamSpec,
    StringFuture,
)
from .utils import _langsmith_traceable

# https://docs.python.org/3/library/typing.html#typing.ParamSpec
# https://docs.python.org/3/library/typing.html#typing.Concatenate
# https://peps.python.org/pep-0612/
# Callable[P, T] is used for static type inference (Pylance)
P = ParamSpec("P")
T = TypeVar("T")
F = TypeVar("F", bound=Callable)  # function
M = TypeVar("M")  # model
R = TypeVar("R")  # return value


def auto_prime_gen(gen_func):
    """Decorate a generator function to automatically prime the generator."""

    def wrapper(*args, **kwargs):
        gen = gen_func(*args, **kwargs)
        next(gen)  # prime the generator
        return gen

    return wrapper

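
# A minimal usage sketch for `auto_prime_gen` (illustrative only, not part of
# the original module). Priming advances a generator to its first `yield`, so
# callers can `send()` values immediately without an initial `next()`.
def _example_auto_prime() -> None:
    @auto_prime_gen
    def echo():
        received = yield  # the decorator advances execution to this point
        while True:
            received = yield f"echo: {received}"

    gen_obj = echo()  # already primed
    assert gen_obj.send("hi") == "echo: hi"
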

@overload
def ppl(ctx: F) -> F: ...


@overload
def ppl(
    ctx: str = "new",
    comp: Optional[Compositor] = None,
    *,
    default_return: Optional[Literal["prompt"]] = None,
    include_docstring: bool = False,
    auto_prime: bool = False,
    num_extra_wrappers: int = 0,
    new_ctx_func: Callable = PromptContext,
) -> Callable[[F], F]: ...


def ppl(
    ctx: Union[str, F] = "new",
    comp: Optional[Compositor] = None,
    *,
    default_return: Optional[Literal["prompt"]] = None,
    include_docstring: bool = False,
    auto_prime: bool = False,
    num_extra_wrappers: int = 0,
    new_ctx_func: Callable = PromptContext,
) -> Union[Callable[[F], F], F]:
100 """Decorate a function to mark it as an APPL function. 

101 

102 The function contains a prompt context, which could be same as or 

103 copied from its caller function, or created from scratch, or resumed 

104 from the last run. 

105 

106 Args: 

107 ctx (str): 

108 the method to deal with the child context, available methods includes: 

109 

110 - (default) "new" or "new_ctx": create a brand new context. 

111 - "copy" or "copy_ctx": 

112 copy from the parent's context, the change will not 

113 affect the parent's context. 

114 - "same" or "same_ctx": 

115 use the same context as the parent's, the change will 

116 affect the parent's context. 

117 - "resume" or "resume_ctx": 

118 resume its own context from the last run. 

119 For the first run, it will use the parent's context. 

120 

121 comp (Compositor, optional): 

122 the default compositor to be used. Defaults to None. 

123 default_return (str, optional): 

124 The default return value, "prompt" means return the prompt within 

125 the function. Defaults to None. 

126 include_docstring (bool, optional): 

127 set to True to include the triple-quoted docstring in the prompt. 

128 Defaults to False. 

129 auto_prime (bool, optional): 

130 set to True to automatically prime the generator. Defaults to False. 

131 num_extra_wrappers (int, optional): 

132 the number of extra wrappers to go back to the caller frame. 

133 new_ctx_func (Callable, optional): 

134 the function to create a new context. Defaults to PromptContext. 

135 """ 

    # The same docstring as PromptFunc (excluding the func argument)

    ctx_method: str = "new"

    def decorator(func: F) -> F:
        """Decorate a function as a prompt function."""
        _is_class_method = False
        if "." in (qualname := func.__qualname__):
            # NOTE: this is a workaround for class methods, may not cover all cases
            qualnames = qualname.split(".")
            if qualnames[-2] != "<locals>":
                _is_class_method = True

        # ? should disable such usage?
        # if not _is_class_method and "<locals>" in qualname and ctx_method == "resume":
        #     raise ValueError("Cannot use 'resume' with local functions.")
        prompt_func = PromptFunc(
            func, ctx_method, comp, default_return, include_docstring, new_ctx_func
        )

        @need_ctx
        @traceable()
        @_langsmith_traceable(name=func.__qualname__, metadata={"appl": "func"})  # type: ignore
        @wraps(func)
        def wrapper(
            *args: Any,
            _globals: Optional[Dict] = None,
            _locals: Optional[Dict] = None,
            **kwargs: Any,
        ) -> Any:
            # closure variables
            freevars = prompt_func.compiled_func.freevars
            if _locals is None:
                # * Workaround for closure variables
                # Default: use the locals from the caller
                frame = inspect.currentframe()
                num_wrappers = (4 if auto_prime else 3) + num_extra_wrappers
                for _ in range(num_wrappers):
                    if frame is None:
                        raise RuntimeError("No caller frame found")
                    # back through @_langsmith_traceable, @traceable, and the caller frame
                    frame = frame.f_back
                if frame is None:
                    raise RuntimeError("No caller frame found")
                _locals = frame.f_locals

            if len(freevars):
                vars = {var: _locals.get(var, "NotFound") for var in freevars}
                logger.debug(
                    f"For freevars of function {func.__name__}, "
                    f"automatically using locals from the caller: {vars}"
                )
                for var in freevars:
                    if var not in _locals:
                        logger.warning(
                            f"Could not find variable {var} automatically from "
                            f"the caller frame for function {func.__name__}. "
                            "If you have a wrapper around the function, you may need "
                            "to set `num_extra_wrappers` in the @ppl decorator."
                        )
            results = prompt_func(
                *args,
                _globals=_globals,
                _locals=_locals,
                _is_class_method=_is_class_method,
                **kwargs,
            )

            return results

        if auto_prime:
            wrapper = auto_prime_gen(wrapper)
        setattr(wrapper, "_prompt_func", prompt_func)
        return wrapper  # type: ignore

    if isinstance(ctx, str):
        ctx_method = ctx
        # used as a decorator with arguments (e.g., @ppl(ctx="copy"))
        # returns a decorator that takes a function as input
        return decorator
    else:
        # used as a plain decorator (i.e., @ppl)
        return decorator(func=ctx)  # returns a wrapper

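
# A minimal usage sketch of @ppl (illustrative only, assuming APPL's convention
# that bare expression statements inside a @ppl function grow the prompt).
@ppl
def _example_ppl(topic: str):
    f"Write one sentence about {topic}."  # grows the prompt context
    return gen()  # sends the accumulated prompt to the default server


# With ctx="copy", changes do not propagate back to the caller's context:
@ppl(ctx="copy")
def _example_ppl_copy(topic: str):
    f"An extra line that only affects the copied context: {topic}"
    return records()
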

def reset_context(func: Callable) -> None:
    """Reset the context for APPL functions with the 'resume' context method."""
    if prompt_func := getattr(func, "_prompt_func", None):
        if reset_func := getattr(prompt_func, "_reset_context_func", None):
            reset_func()
            logger.info(f"Context reset for function {func.__name__}")
        else:
            logger.warning(f"Nothing to reset for function {func.__name__}")
    else:
        logger.warning(f"Not an APPL function: {func.__name__}, cannot reset context.")

def as_func(
    func: Callable[P, T],
    _globals: Optional[Dict] = None,
    _locals: Optional[Dict] = None,
) -> Callable[P, T]:
    """Fill in the globals and locals for a ppl function.

    When locals are not provided, the locals from the caller are used.
    """
    frame = inspect.currentframe()
    if _locals is None and frame is not None and frame.f_back is not None:
        _locals = frame.f_back.f_locals
    return partial(func, _globals=_globals, _locals=_locals)

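
# A hedged sketch of `as_func` (illustrative; `_example_ppl` is the sketch
# defined above). Binding the caller's frame up front lets the returned
# callable be invoked later, from a different frame, while still resolving
# free variables against the locals captured here.
def _example_as_func() -> Callable:
    # as_func reads this frame's locals by default, so free variables of the
    # wrapped @ppl function defined here would still be found at call time
    return as_func(_example_ppl)
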

def str_future(obj: Any) -> StringFuture:
    """Convert an object to a StringFuture object."""
    return StringFuture(obj)

def as_tool(func: Callable, **kwargs: Any) -> Tool:
    """Wrap a given function with additional predefined arguments into a Tool.

    This function allows converting a standard function into a 'Tool' by
    specifying the function and any additional arguments that should be
    predefined for it. These additional arguments are passed as keyword
    arguments and will be bound to the function within the Tool object,
    so that these arguments are not required when using this tool.

    Args:
        func (Callable):
            The function to be converted into a Tool.
        **kwargs:
            Keyword arguments that will be predefined for the function in
            the Tool object.

    Returns:
        Tool:
            An object encapsulating the given function and its predefined
            arguments, ready to be utilized as a Tool.

    Examples:
        Given a function `move_disk` that requires an environment and two
        pegs to move a disk from one peg to another in the Tower of Hanoi
        puzzle, one can create a tool with a predefined environment by:

        ```python
        def move_disk(env: HanoiEnv, from_peg: int, to_peg: int) -> str:
            pass

        env = HanoiEnv()
        tools = [as_tool(move_disk, env=env)]
        ```

        In this example, `move_disk` is encapsulated into a Tool with `env`
        predefined, so only `from_peg` and `to_peg` are required.
    """
    return Tool(func=func, **kwargs)

def as_tool_choice(obj: Union[str, Callable, BaseTool]) -> dict:
    """Build a tool choice argument for the OpenAI API from an object."""
    if isinstance(obj, BaseTool):
        name = obj.name
    else:
        name = getattr(obj, "__name__", str(obj))
    return dict(type="function", function=dict(name=name))

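
# An illustrative check of `as_tool_choice` on a plain function (the name
# `get_weather` is hypothetical): the result matches the OpenAI tool_choice
# payload shape built above.
def _example_tool_choice() -> None:
    def get_weather(city: str) -> str:
        """Return a weather summary for the given city."""
        return f"sunny in {city}"

    assert as_tool_choice(get_weather) == {
        "type": "function",
        "function": {"name": "get_weather"},
    }
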

def call(
    func: Callable,
    *args: Any,
    executor_type: ExecutorType = ExecutorType.GENERAL_THREAD_POOL,
    **kwargs: Any,
) -> CallFuture:
    """Create a CallFuture object from a function and its arguments.

    The CallFuture object will call the function in a separate thread or process,
    therefore the function needs to be thread-safe or process-safe.
    """
    return CallFuture(func, *args, executor_type=executor_type, **kwargs)

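
# An illustrative sketch for `call`: offload a blocking, thread-safe function
# to the default thread pool. Retrieving the result follows CallFuture's own
# API and is intentionally not shown here.
def _example_call() -> CallFuture:
    import time

    def slow_square(x: int) -> int:
        time.sleep(0.5)  # simulate slow but thread-safe work
        return x * x

    return call(slow_square, 7)  # returns a CallFuture immediately
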

def openai_tool_schema(func: Callable) -> dict:
    """Build an OpenAI tool schema from a function."""
    return as_tool(func).openai_schema

@need_ctx
def get_var(name: str, _ctx: PromptContext) -> Any:
    """Get a variable by name from the prompt context."""
    return getattr(_ctx, name)

@need_ctx
def records(_ctx: Optional[PromptContext] = None) -> PromptRecords:
    """Return the prompt defined in the current function.

    Similar to locals() in Python, in some sense.
    """
    # the default value for _ctx only exists to avoid type-checker warnings
    if _ctx is None:
        raise ValueError(
            "PromptContext is required for records; "
            "this function should be called within a @ppl function."
        )
    return _ctx.records

@need_ctx
def convo(_ctx: Optional[PromptContext] = None) -> Conversation:
    """Return the full conversation in the context.

    Similar to globals() in Python, in some sense.
    """
    # the default value for _ctx only exists to avoid type-checker warnings
    if _ctx is None:
        raise ValueError(
            "PromptContext is required for convo; "
            "this function should be called within a @ppl function."
        )
    return _ctx.messages

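
# An illustrative sketch contrasting records() and convo() inside a @ppl
# function (again assuming the bare-expression prompt convention noted above).
@ppl
def _example_inspect(user_input: str):
    f"User said: {user_input}"
    local_prompt = records()  # only this function's prompt, like locals()
    full_history = convo()  # the whole conversation in context, like globals()
    return local_prompt, full_history
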

def empty_line(num_lines: int = 1) -> PromptRecords:
    """Create empty lines regardless of other compositors."""
    records = PromptRecords()
    records.record(PrinterPush(separator="\n", indexing=Indexing(), new_indent=""))
    for _ in range(num_lines):
        records.record("")
    records.record(PrinterPop())
    return records

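
# An illustrative sketch of `empty_line`: the pushed printer scope has no
# indent or indexing, so the blank lines separate prompt sections regardless
# of the active compositor. Growing the returned PromptRecords is assumed to
# work via grow() as defined below.
@ppl
def _example_sections():
    grow("Section A")
    grow(empty_line(2))  # two blank lines between the sections
    grow("Section B")
    return records()
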

def build_tools(tools: OneOrMany[Union[BaseTool, Callable]]) -> Sequence[BaseTool]:
    """Build a list of tools from the given tools or functions."""

    def convert_to_tool(tool: Union[BaseTool, Callable]) -> BaseTool:
        if isinstance(tool, BaseTool):
            return tool
        if callable(tool):
            return as_tool(tool)
        raise ValueError(f"Invalid tool: {tool}")

    # process tools
    if isinstance(tools, BaseTool) or callable(tools):
        return [convert_to_tool(tools)]
    if isinstance(tools, Sequence):
        return [convert_to_tool(tool) for tool in tools]
    raise ValueError(f"Invalid tools: {tools}")

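
# An illustrative check of `build_tools` normalization: a single callable, a
# single BaseTool, or a mixed sequence all become a sequence of BaseTool objects.
def _example_build_tools() -> None:
    def search(query: str) -> str:
        """Search for a query."""
        return f"results for {query}"

    assert len(build_tools(search)) == 1
    assert len(build_tools([search, as_tool(search)])) == 2
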

@need_ctx
def grow(content: Any, *, _ctx: Optional[PromptContext] = None) -> None:
    """Append the content to the prompt."""
    if _ctx is None:
        raise ValueError(
            "PromptContext is required for appending. "
            "Normally, it should be automatically filled."
        )
    appl_execute(content, _ctx=_ctx)

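
# An illustrative sketch of `grow`: the explicit counterpart of writing a bare
# expression in a @ppl function, useful in loops and helper code.
@ppl
def _example_grow(items: Sequence[str]):
    grow("Here is the list:")
    for item in items:
        grow(f"- {item}")  # same effect as the bare expression f"- {item}"
    return records()
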

@need_ctx
def gen(
    server: Optional[str] = None,
    *,
    max_tokens: Optional[int] = None,
    stop: MaybeOneOrMany[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    n: Optional[int] = None,
    tools: OneOrMany[Union[BaseTool, Callable]] = [],  # TODO: support dict
    tool_format: str = "auto",
    stream: Optional[bool] = None,
    response_format: Optional[Union[dict, str, Type[M]]] = None,
    response_model: Optional[Type[M]] = None,
    max_relay_rounds: int = 0,
    mock_response: Optional[Union[CompletionResponse, str]] = None,
    messages_process_func: Optional[Callable[[Conversation], Conversation]] = None,
    _ctx: Optional[PromptContext] = None,
    **kwargs: Any,
) -> Generation[M]:
416 """Send a generation request to the LLM backend. 

417 

418 Args: 

419 server (str, optional): 

420 name of the backend server. Defaults to the default server set in the configs. 

421 max_tokens (int, optional): maximum number of tokens to generate. Defaults to None. 

422 stop (str|Sequence[str], optional): stop sequence(s). Defaults to None. 

423 temperature (float, optional): temperature for sampling. Defaults to None. 

424 top_p (float, optional): nucleus sampling parameter. Defaults to None. 

425 n (int, optional): number of choices to generate. Defaults to 1. 

426 tools (BaseTool|Callable|Sequence[BaseTool|Callable], optional): 

427 tools can be used. Defaults to None. 

428 tool_format (str, optional): the format for the tools. Defaults to "auto". 

429 stream (bool, optional): whether to stream the results. Defaults to False. 

430 response_format (Union[dict, str, Type[M]], optional): 

431 OpenAI's argument specifies the response format. Defaults to None. 

432 response_model (Type[M], optional): 

433 instructor's argument specifies the response format as a Pydantic model. 

434 use `instructor_patch_mode` to specify the mode for patching the raw completion. 

435 Recommended to use `response_format` instead. Defaults to None. 

436 max_relay_rounds (int, optional): 

437 the maximum number of relay rounds to continue the unfinished text generation. Defaults to 0. 

438 mock_response (Union[CompletionResponse, str], optional): 

439 mock response for testing. Defaults to None. 

440 messages_process_func (Callable[[Conversation], Conversation], optional): 

441 a function to process the messages before sending to the LLM. 

442 Defaults to None. 

443 _ctx (PromptContext): prompt context, will be automatically filled. 

444 kwargs (Any): extra arguments for the generation. 

445 

446 Returns: 

447 Generation: a future object representing the generation result 

448 """ 

    backend_server = server_manager.get_server(server)
    if _ctx is None:
        raise ValueError(
            "PromptContext is required for generation. "
            "Normally, it should be automatically filled."
        )
    messages = _ctx.messages
    messages.materialize()  # materialize the messages
    # TODO: double check the correctness
    messages = copy.deepcopy(messages)  # freeze the prompt for the generation
    if messages_process_func:
        messages = messages_process_func(messages)

    if isinstance(response_format, str):
        if response_format != "json":
            raise ValueError(
                "Only 'json' is supported for response_format in string format."
            )
        response_format = {"type": "json_object"}

    response_format_is_pydantic_model = False
    if response_format is not None:
        if isinstance(response_format, dict):
            if "type" not in response_format:
                raise ValueError("response_format must specify the type.")
            if response_format["type"] not in ("json_object", "json_schema"):
                raise ValueError(
                    "Only 'json_object' and 'json_schema' are supported for response_format."
                )
        elif isinstance(response_format, type) and issubclass(
            response_format, BaseModel
        ):
            response_format_is_pydantic_model = True
        else:
            raise ValueError(
                "Invalid response_format; it must be a dict or a Pydantic model."
            )

        if response_model is not None:
            raise ValueError(
                "response_format and response_model cannot be used together."
            )

    if max_relay_rounds > 0:
        if response_format_is_pydantic_model or response_model is not None:
            raise ValueError(
                "max_relay_rounds cannot be used when response_format is "
                "a Pydantic model or response_model is specified."
            )
        elif response_format is not None:
            logger.warning(
                "Automatic continuation may not work well when response_format is specified. "
                "Recommend using plain text generation instead."
            )

    if (
        isinstance(get_origin(response_format), type)
        or get_origin(response_format) is Literal
        or isinstance(response_format, type)
        and not issubclass(response_format, BaseModel)
    ):
        # wrap plain types (e.g., int, Literal[...]) in a Pydantic model
        class Response(BaseModel):
            response: response_format  # type: ignore

        response_format = Response  # type: ignore
        kwargs["_wrapped_attribute"] = "response"

    create_args = GenArgs(
        model=backend_server.model_name,
        messages=messages,
        max_tokens=max_tokens,
        stop=stop,
        temperature=temperature,
        top_p=top_p,
        n=n,
        tools=build_tools(tools),
        tool_format=tool_format,  # type: ignore
        stream=stream,
        response_format=response_format,  # type: ignore
        response_model=response_model,  # type: ignore
    )

    generation = Generation[M](
        backend_server,
        create_args,
        max_relay_rounds=max_relay_rounds,
        mock_response=mock_response,
        _ctx=_ctx,
        **kwargs,
    )

    @_langsmith_traceable(name=generation.id, metadata={"appl": "gen"})  # type: ignore
    def langsmith_trace(*args: Any, **kwargs: Any) -> None:
        pass

    langsmith_trace(backend_server, create_args, _ctx=_ctx, **kwargs)
    return generation

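
# An illustrative sketch of structured output with `gen` (the model name
# `_ExampleAnswer` is hypothetical): a BaseModel subclass passes through the
# response_format validation above, while a plain type such as `int` would be
# wrapped in the internal Response model.
class _ExampleAnswer(BaseModel):
    answer: str
    confidence: float


@ppl
def _example_structured(question: str):
    f"Question: {question}"
    "Answer and give your confidence in [0, 1]."
    return gen(response_format=_ExampleAnswer, temperature=0.0)
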

# def serialize(obj: Any) -> Any:
#     if hasattr(obj, "serialize"):
#         return obj.serialize()
#     else:
#         raise TypeError("Object of type '%s' is not serializable" % type(obj).__name__)