Coverage for src/appl/func.py: 73% (174 statements)
import copy
import inspect
from typing import (
    Any,
    Callable,
    Dict,
    Literal,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
    get_origin,
    overload,
)

from loguru import logger
from pydantic import BaseModel

from .core import (
    BaseTool,
    Compositor,
    Conversation,
    GenArgs,
    Generation,
    PrinterPop,
    PrinterPush,
    PromptContext,
    PromptFunc,
    PromptRecords,
    Tool,
    need_ctx,
    partial,
    wraps,
)
from .core.printer import Indexing
from .core.response import CompletionResponse
from .core.runtime import appl_execute
from .core.trace import traceable
from .servers import server_manager
from .types import (
    CallFuture,
    ExecutorType,
    MaybeOneOrMany,
    OneOrMany,
    ParamSpec,
    StringFuture,
)
from .utils import _langsmith_traceable

# https://docs.python.org/3/library/typing.html#typing.ParamSpec
# https://docs.python.org/3/library/typing.html#typing.Concatenate
# https://peps.python.org/pep-0612/
# Callable[P, T] is used for static type inference (Pylance)
P = ParamSpec("P")
T = TypeVar("T")
F = TypeVar("F", bound=Callable)  # function
M = TypeVar("M")  # model
R = TypeVar("R")  # return value


def auto_prime_gen(gen_func):
    """Decorate a generator to automatically prime the generator."""

    def wrapper(*args, **kwargs):
        gen = gen_func(*args, **kwargs)
        next(gen)  # prime the generator
        return gen

    return wrapper
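

# Illustrative sketch of auto_prime_gen: the decorated generator is advanced
# to its first `yield` on creation, so callers can immediately `send()` to it.
# (`echo` below is a hypothetical example, not part of this module.)
#
#     @auto_prime_gen
#     def echo():
#         received = None
#         while True:
#             received = yield received
#
#     g = echo()          # already primed; no initial next() needed
#     assert g.send("hi") == "hi"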


@overload
def ppl(ctx: F) -> F: ...


@overload
def ppl(
    ctx: str = "new",
    comp: Optional[Compositor] = None,
    *,
    default_return: Optional[Literal["prompt"]] = None,
    include_docstring: bool = False,
    auto_prime: bool = False,
    num_extra_wrappers: int = 0,
    new_ctx_func: Callable = PromptContext,
) -> Callable[[F], F]: ...


def ppl(
    ctx: Union[str, F] = "new",
    comp: Optional[Compositor] = None,
    *,
    default_return: Optional[Literal["prompt"]] = None,
    include_docstring: bool = False,
    auto_prime: bool = False,
    num_extra_wrappers: int = 0,
    new_ctx_func: Callable = PromptContext,
) -> Union[Callable[[F], F], F]:
100 """Decorate a function to mark it as an APPL function.
102 The function contains a prompt context, which could be same as or
103 copied from its caller function, or created from scratch, or resumed
104 from the last run.
106 Args:
107 ctx (str):
108 the method to deal with the child context, available methods includes:
110 - (default) "new" or "new_ctx": create a brand new context.
111 - "copy" or "copy_ctx":
112 copy from the parent's context, the change will not
113 affect the parent's context.
114 - "same" or "same_ctx":
115 use the same context as the parent's, the change will
116 affect the parent's context.
117 - "resume" or "resume_ctx":
118 resume its own context from the last run.
119 For the first run, it will use the parent's context.
121 comp (Compositor, optional):
122 the default compositor to be used. Defaults to None.
123 default_return (str, optional):
124 The default return value, "prompt" means return the prompt within
125 the function. Defaults to None.
126 include_docstring (bool, optional):
127 set to True to include the triple-quoted docstring in the prompt.
128 Defaults to False.
129 auto_prime (bool, optional):
130 set to True to automatically prime the generator. Defaults to False.
131 num_extra_wrappers (int, optional):
132 the number of extra wrappers to go back to the caller frame.
133 new_ctx_func (Callable, optional):
134 the function to create a new context. Defaults to PromptContext.
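
    Examples:
        A minimal illustrative sketch (assumes a configured default LLM
        server; `greet` is a hypothetical function, not part of this module):

        ```python
        @ppl
        def greet(name: str):
            f"Hello, my name is {name}."  # captured into the prompt
            return gen()  # sends the prompt to the default server
        ```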
135 """
    # The same docstring as PromptFunc (excluding the func argument)

    ctx_method: str = "new"

    def decorator(func: F) -> F:
        """Decorate a function as a prompt function."""
        _is_class_method = False
        if "." in (qualname := func.__qualname__):
            # NOTE: this is a workaround for class methods, may not cover all cases
            qualnames = qualname.split(".")
            if qualnames[-2] != "<locals>":
                _is_class_method = True

        # ? should disable such usage?
        # if not _is_class_method and "<locals>" in qualname and ctx_method == "resume":
        #     raise ValueError("Cannot use 'resume' with local functions.")
        prompt_func = PromptFunc(
            func, ctx_method, comp, default_return, include_docstring, new_ctx_func
        )

        @need_ctx
        @traceable()
        @_langsmith_traceable(name=func.__qualname__, metadata={"appl": "func"})  # type: ignore
        @wraps(func)
        def wrapper(
            *args: Any,
            _globals: Optional[Dict] = None,
            _locals: Optional[Dict] = None,
            **kwargs: Any,
        ) -> Any:
            # closure variables
            freevars = prompt_func.compiled_func.freevars
            if _locals is None:
                # * Workaround for closure variables
                # Default: use the locals from the caller
                frame = inspect.currentframe()
                num_wrappers = (4 if auto_prime else 3) + num_extra_wrappers
                for _ in range(num_wrappers):
                    if frame is None:
                        raise RuntimeError("No caller frame found")
                    # back to @_langsmith_traceable, @traceable, and the caller frame
                    frame = frame.f_back
                if frame is None:
                    raise RuntimeError("No caller frame found")
                _locals = frame.f_locals

            if len(freevars):
                vars = {var: _locals.get(var, "NotFound") for var in freevars}
                logger.debug(
                    f"For freevars of function {func.__name__}, "
                    f"automatically using locals from the caller: {vars}"
                )
                for var in freevars:
                    if var not in _locals:
                        logger.warning(
                            f"Could not find variable {var} automatically from "
                            f"the caller frame for function {func.__name__}. "
                            "If you have a wrapper around the function, you may "
                            "need to set `num_extra_wrappers` in the @ppl decorator."
                        )
            results = prompt_func(
                *args,
                _globals=_globals,
                _locals=_locals,
                _is_class_method=_is_class_method,
                **kwargs,
            )

            return results

        if auto_prime:
            wrapper = auto_prime_gen(wrapper)
        setattr(wrapper, "_prompt_func", prompt_func)
        return wrapper  # type: ignore

    if isinstance(ctx, str):
        ctx_method = ctx
        # used as a decorator with arguments (e.g., @ppl(ctx="copy"))
        # returns a decorator that takes a function as input
        return decorator
    else:
        # used as a single decorator (i.e., @ppl)
        return decorator(func=ctx)  # returns a wrapper


def reset_context(func: Callable) -> None:
    """Reset the context for APPL functions with the 'resume' context method."""
    if prompt_func := getattr(func, "_prompt_func", None):
        if reset_func := getattr(prompt_func, "_reset_context_func", None):
            reset_func()
            logger.info(f"Context reset for function {func.__name__}")
        else:
            logger.warning(f"Nothing to reset for function {func.__name__}")
    else:
        logger.warning(f"Not an APPL function: {func.__name__}, cannot reset context.")


def as_func(
    func: Callable[P, T],
    _globals: Optional[Dict] = None,
    _locals: Optional[Dict] = None,
) -> Callable[P, T]:
    """Fill the globals and locals for a ppl function.

    When locals are not provided, the locals from the caller are used.
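
    Example (an illustrative sketch; `my_ppl_func` is a hypothetical
    @ppl function):

    ```python
    fn = as_func(my_ppl_func)  # binds the caller's locals
    result = fn()
    ```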
241 """
242 frame = inspect.currentframe()
243 if _locals is None and frame is not None and frame.f_back is not None:
244 _locals = frame.f_back.f_locals
245 return partial(func, _globals=_globals, _locals=_locals)


def str_future(obj: Any) -> StringFuture:
    """Convert an object to a StringFuture object."""
    return StringFuture(obj)


def as_tool(func: Callable, **kwargs: Any) -> Tool:
    """Wrap a given function with additional predefined arguments into a Tool.

    This function allows converting a standard function into a 'Tool' by
    specifying the function and any additional arguments that should be
    predefined for it. These additional arguments are passed as keyword
    arguments and will be bound to the function within the Tool object,
    so that these arguments are not required when using this tool.

    Args:
        func (Callable):
            The function to be converted into a Tool.
        **kwargs:
            Keyword arguments that will be predefined for the function in
            the Tool object.

    Returns:
        Tool:
            An object encapsulating the given function and its predefined
            arguments, ready to be utilized as a Tool.

    Examples:
        Given a function `move_disk` that requires an environment and two
        pegs to move a disk from one peg to another in the Tower of Hanoi
        puzzle, one can create a tool with a predefined environment by:

        ```python
        def move_disk(env: HanoiEnv, from_peg: int, to_peg: int) -> str:
            pass

        env = HanoiEnv()
        tools = [as_tool(move_disk, env=env)]
        ```

        In this example, `move_disk` is encapsulated into a Tool with `env`
        predefined, so only `from_peg` and `to_peg` are required.
    """
    return Tool(func=func, **kwargs)


def as_tool_choice(obj: Union[str, Callable, BaseTool]) -> dict:
    """Build a tool choice argument for the OpenAI API from an object."""
    if isinstance(obj, BaseTool):
        name = obj.name
    else:
        name = getattr(obj, "__name__", str(obj))
    return dict(type="function", function=dict(name=name))
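

# Illustrative sketch: as_tool_choice("search") evaluates to
#     {"type": "function", "function": {"name": "search"}}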


def call(
    func: Callable,
    *args: Any,
    executor_type: ExecutorType = ExecutorType.GENERAL_THREAD_POOL,
    **kwargs: Any,
) -> CallFuture:
    """Create a CallFuture object from a function and its arguments.

    The CallFuture object will call the function in a separate thread or process,
    therefore the function needs to be thread-safe or process-safe.
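
    Example (an illustrative sketch; `expensive_compute` is hypothetical):

    ```python
    future = call(expensive_compute, 42)
    # the call runs asynchronously in the default thread pool;
    # the result is obtained from the future when needed.
    ```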
312 """
313 return CallFuture(func, *args, executor_type=executor_type, **kwargs)


def openai_tool_schema(func: Callable) -> dict:
    """Build an OpenAI tool schema from a function."""
    return as_tool(func).openai_schema


@need_ctx
def get_var(name: str, _ctx: PromptContext) -> Any:
    """Get a variable by name from the prompt context."""
    return getattr(_ctx, name)


@need_ctx
def records(_ctx: Optional[PromptContext] = None) -> PromptRecords:
    """Return the prompt defined in the current function.

    Roughly analogous to locals() in Python.
    """
    # add default value for _ctx to avoid the warning of the type checker
    if _ctx is None:
        raise ValueError(
            "PromptContext is required for records, "
            "this function should be called within a @ppl function."
        )
    return _ctx.records


@need_ctx
def convo(_ctx: Optional[PromptContext] = None) -> Conversation:
    """Return the full conversation in the context.

    Roughly analogous to globals() in Python.
    """
    # Added default value for _ctx to avoid the warning of the type checker
    if _ctx is None:
        raise ValueError(
            "PromptContext is required for convo, "
            "this function should be called within a @ppl function."
        )
    return _ctx.messages


def empty_line(num_lines: int = 1) -> PromptRecords:
    """Create empty lines in the prompt regardless of the enclosing compositors."""
    records = PromptRecords()
    records.record(PrinterPush(separator="\n", indexing=Indexing(), new_indent=""))
    for _ in range(num_lines):
        records.record("")
    records.record(PrinterPop())
    return records
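

# Illustrative sketch: inside a @ppl function, writing `empty_line(2)` as an
# expression statement appends two blank lines to the prompt, bypassing the
# separator, indexing, and indentation of any enclosing compositor.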


def build_tools(tools: OneOrMany[Union[BaseTool, Callable]]) -> Sequence[BaseTool]:
    """Build a list of tools from the given tools or functions."""

    def convert_to_tool(tool: Union[BaseTool, Callable]) -> BaseTool:
        if isinstance(tool, BaseTool):
            return tool
        if callable(tool):
            return as_tool(tool)
        raise ValueError(f"Invalid tool: {tool}")

    # process tools
    if isinstance(tools, BaseTool) or callable(tools):
        return [convert_to_tool(tools)]
    if isinstance(tools, Sequence):
        return [convert_to_tool(tool) for tool in tools]
    raise ValueError(f"Invalid tools: {tools}")
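

# Illustrative sketch: a single function, a single BaseTool, or a sequence of
# either is normalized to a list of BaseTool objects (`add` is hypothetical):
#
#     def add(a: int, b: int) -> int:
#         return a + b
#
#     build_tools(add)          # -> [Tool(func=add)]
#     build_tools([add, add])   # -> [Tool(func=add), Tool(func=add)]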


@need_ctx
def grow(content: Any, *, _ctx: Optional[PromptContext] = None) -> None:
    """Append the content to the prompt."""
    if _ctx is None:
        raise ValueError(
            "PromptContext is required for appending. "
            "Normally, it should be automatically filled."
        )
    appl_execute(content, _ctx=_ctx)
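

# Illustrative sketch: inside a @ppl function, `grow("You are a helpful
# assistant.")` appends that text to the current prompt; it is the explicit
# form of writing the string as a bare expression statement.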


@need_ctx
def gen(
    server: Optional[str] = None,
    *,
    max_tokens: Optional[int] = None,
    stop: MaybeOneOrMany[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    n: Optional[int] = None,
    tools: OneOrMany[Union[BaseTool, Callable]] = [],  # TODO: support dict
    tool_format: str = "auto",
    stream: Optional[bool] = None,
    response_format: Optional[Union[dict, str, Type[M]]] = None,
    response_model: Optional[Type[M]] = None,
    max_relay_rounds: int = 0,
    mock_response: Optional[Union[CompletionResponse, str]] = None,
    messages_process_func: Optional[Callable[[Conversation], Conversation]] = None,
    _ctx: Optional[PromptContext] = None,
    **kwargs: Any,
) -> Generation[M]:
416 """Send a generation request to the LLM backend.
418 Args:
419 server (str, optional):
420 name of the backend server. Defaults to the default server set in the configs.
421 max_tokens (int, optional): maximum number of tokens to generate. Defaults to None.
422 stop (str|Sequence[str], optional): stop sequence(s). Defaults to None.
423 temperature (float, optional): temperature for sampling. Defaults to None.
424 top_p (float, optional): nucleus sampling parameter. Defaults to None.
425 n (int, optional): number of choices to generate. Defaults to 1.
426 tools (BaseTool|Callable|Sequence[BaseTool|Callable], optional):
427 tools can be used. Defaults to None.
428 tool_format (str, optional): the format for the tools. Defaults to "auto".
429 stream (bool, optional): whether to stream the results. Defaults to False.
430 response_format (Union[dict, str, Type[M]], optional):
431 OpenAI's argument specifies the response format. Defaults to None.
432 response_model (Type[M], optional):
433 instructor's argument specifies the response format as a Pydantic model.
434 use `instructor_patch_mode` to specify the mode for patching the raw completion.
435 Recommended to use `response_format` instead. Defaults to None.
436 max_relay_rounds (int, optional):
437 the maximum number of relay rounds to continue the unfinished text generation. Defaults to 0.
438 mock_response (Union[CompletionResponse, str], optional):
439 mock response for testing. Defaults to None.
440 messages_process_func (Callable[[Conversation], Conversation], optional):
441 a function to process the messages before sending to the LLM.
442 Defaults to None.
443 _ctx (PromptContext): prompt context, will be automatically filled.
444 kwargs (Any): extra arguments for the generation.
446 Returns:
447 Generation: a future object representing the generation result
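
    Examples:
        A minimal illustrative sketch (assumes a configured default server):

        ```python
        @ppl
        def summarize(text: str):
            f"Summarize the following text:\n{text}"
            return gen(max_tokens=128, temperature=0.0)
        ```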
448 """
    backend_server = server_manager.get_server(server)
    if _ctx is None:
        raise ValueError(
            "PromptContext is required for generation. "
            "Normally, it should be automatically filled."
        )
    messages = _ctx.messages
    messages.materialize()  # materialize the messages
    # TODO: double check the correctness
    messages = copy.deepcopy(messages)  # freeze the prompt for the generation
    if messages_process_func:
        messages = messages_process_func(messages)

    if isinstance(response_format, str):
        if response_format != "json":
            raise ValueError(
                "Only 'json' is supported for response_format in string format."
            )
        response_format = {"type": "json_object"}

    response_format_is_pydantic_model = False
    if response_format is not None:
        if isinstance(response_format, dict):
            if "type" not in response_format:
                raise ValueError("response_format must specify the type.")
            if response_format["type"] not in ("json_object", "json_schema"):
                raise ValueError(
                    "Only 'json_object' and 'json_schema' are supported for response_format."
                )
        elif isinstance(response_format, type) and issubclass(
            response_format, BaseModel
        ):
            response_format_is_pydantic_model = True
        else:
            raise ValueError(
                "Invalid response_format; it should be a dict or a Pydantic model."
            )

        if response_model is not None:
            raise ValueError(
                "response_format and response_model cannot be used together."
            )

    if max_relay_rounds > 0:
        if response_format_is_pydantic_model or response_model is not None:
            raise ValueError(
                "max_relay_rounds cannot be used when response_format is "
                "a Pydantic model or response_model is specified."
            )
        elif response_format is not None:
            logger.warning(
                "Automatic continuation may not work well when response_format is "
                "specified. Using plain text generation is recommended instead."
            )

    if (
        isinstance(get_origin(response_format), type)
        or get_origin(response_format) is Literal
        or (
            isinstance(response_format, type)
            and not issubclass(response_format, BaseModel)
        )
    ):
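
        # Wrap a plain type (e.g., a generic alias or a Literal) in a Pydantic
        # model so it can serve as a structured response format; the
        # `_wrapped_attribute` hint presumably lets downstream code unwrap the
        # `response` field from the parsed result.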
        class Response(BaseModel):
            response: response_format  # type: ignore

        response_format = Response  # type: ignore
        kwargs["_wrapped_attribute"] = "response"

    create_args = GenArgs(
        model=backend_server.model_name,
        messages=messages,
        max_tokens=max_tokens,
        stop=stop,
        temperature=temperature,
        top_p=top_p,
        n=n,
        tools=build_tools(tools),
        tool_format=tool_format,  # type: ignore
        stream=stream,
        response_format=response_format,  # type: ignore
        response_model=response_model,  # type: ignore
    )

    generation = Generation[M](
        backend_server,
        create_args,
        max_relay_rounds=max_relay_rounds,
        mock_response=mock_response,
        _ctx=_ctx,
        **kwargs,
    )

    @_langsmith_traceable(name=generation.id, metadata={"appl": "gen"})  # type: ignore
    def langsmith_trace(*args: Any, **kwargs: Any) -> None:
        pass

    langsmith_trace(backend_server, create_args, _ctx=_ctx, **kwargs)
    return generation


# def serialize(obj: Any) -> Any:
#     if hasattr(obj, "serialize"):
#         return obj.serialize()
#     else:
#         raise TypeError("Object of type '%s' is not serializable" % type(obj).__name__)