Coverage for src/appl/core/runtime.py: 81%

73 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1"""helper functions for runtime execution within the compiled function.""" 

2 

3import inspect 

4from argparse import Namespace 

5from typing import Any, Callable, Dict, Iterable 

6 

7from loguru import logger 

8from PIL.ImageFile import ImageFile 

9 

10from .config import configs 

11from .context import PromptContext 

12from .generation import Generation 

13from .message import BaseMessage 

14from .printer import PromptRecords 

15from .promptable import Promptable, promptify 

16from .types import CallFuture, Image, StringFuture 

17 

18 

def appl_with_ctx(
    *args: Any,
    _func: Callable,
    _ctx: PromptContext = PromptContext(),
    _globals: Any = None,
    _locals: Any = None,
    **kwargs: Any,
) -> Any:
    """Call ``_func`` with the given arguments, forwarding the prompt context.

    A few builtins are special-cased so that they act on the namespaces
    supplied by the caller rather than on this wrapper's own frame:

    * ``globals`` / ``locals`` are not called at all; ``_globals`` /
      ``_locals`` are returned instead.
    * ``exec`` / ``eval`` invoked without any namespace receive ``_globals``
      and ``_locals``, mirroring what the bare builtin does at its call site.
      When the caller passes a globals namespace explicitly, the call is
      forwarded untouched so the builtin's own defaulting applies
      (locals defaults to the given globals).

    Args:
        *args: Positional arguments passed through to ``_func``.
        _func: The callable to invoke.
        _ctx: Prompt context. It is injected as the ``_ctx`` keyword argument
            only when ``_func`` has a truthy ``__need_ctx__`` attribute.
            NOTE: the default instance is created once, when this function is
            defined, and is therefore shared by every call that omits ``_ctx``.
        _globals: Namespace returned for ``globals`` and used as the default
            globals of ``exec`` / ``eval``.
        _locals: Namespace returned for ``locals`` and used as the default
            locals of ``exec`` / ``eval``.
        **kwargs: Keyword arguments passed through to ``_func``.

    Returns:
        ``_globals`` or ``_locals`` when ``_func`` is the matching builtin,
        otherwise the return value of ``_func``.
    """
    if _func is globals:
        # Answer with the caller's globals instead of this module's.
        return _globals
    if _func is locals:
        # Answer with the caller's locals instead of this frame's.
        return _locals

    if _func is exec or _func is eval:
        # Only a call that names no globals (positionally or by keyword)
        # needs fixing. Anything else already has well-defined namespaces,
        # and appending to it would either override the builtin's
        # "locals defaults to globals" rule or collide with a keyword
        # argument occupying the same slot.
        if len(args) == 1 and "globals" not in kwargs:
            args = args + (_globals,)
            if "locals" not in kwargs:
                args = args + (_locals,)

    if getattr(_func, "__need_ctx__", False):
        kwargs["_ctx"] = _ctx  # add the context to the kwargs
    return _func(*args, **kwargs)

43 

44 

def appl_execute(
    s: Any,
    _ctx: PromptContext = PromptContext(),
) -> None:
    """Feed a value into the prompt context according to its type.

    ``None`` is ignored. Strings are wrapped in a ``StringFuture`` and added,
    except that the first string seen may be dropped when it equals the
    cleaned function docstring (see the inline comments). Futures, records,
    messages and images are added through the matching ``_ctx.add_*`` method.
    ``Generation`` and ``Promptable`` values are converted and re-dispatched,
    other iterables are dispatched item by item, and a ``Namespace`` updates
    the context variables. Anything else is logged and skipped.

    Args:
        s: The value to add to the prompt.
        _ctx: The prompt context that receives the value.
    """
    if s is None:
        return

    if isinstance(s, str):
        keep = True
        if _ctx._is_first_str:
            # Only the first string of a function can be its docstring.
            doc = _ctx._func_docstring
            if doc is not None:
                doc = inspect.cleandoc(doc)
            if _ctx._include_docstring:
                if doc is not None:
                    assert s == doc, f"Docstring mismatch: {s}"
                else:
                    logger.warning(
                        f"No docstring found for {_ctx._func_name}, cannot include it."
                    )
            elif s == doc and _ctx._docstring_quote_count != 1:
                # The docstring was not requested, so leave it out of the prompt.
                keep = False
                if configs.getattrs(
                    "settings.logging.display.docstring_warning", False
                ):
                    logger.warning(
                        f'The docstring """{s}""" for `{_ctx._func_name}` is excluded from the prompt. '
                        "To include the docstring, set include_docstring=True in the @ppl function."
                    )
        if keep:
            _ctx.add_string(StringFuture(s))
        # Whether or not it was added, the first string has now been consumed.
        _ctx._is_first_str = False
        return

    if isinstance(s, StringFuture):
        _ctx.add_string(s)
        return
    if isinstance(s, PromptRecords):
        _ctx.add_records(s)
        return
    if isinstance(s, BaseMessage):
        _ctx.add_message(s)
        return
    if isinstance(s, Image):
        _ctx.add_image(s)
        return
    if isinstance(s, ImageFile):
        _ctx.add_image(Image.from_image(s))
        return
    if isinstance(s, Generation):
        appl_execute(s.as_prompt(), _ctx)
        return
    if isinstance(s, Promptable):
        # Convert to a prompt value, then dispatch again.
        appl_execute(promptify(s), _ctx)
        return
    if isinstance(s, Iterable):
        # Dispatch every element in order.
        for item in s:
            appl_execute(item, _ctx)
        return
    if isinstance(s, Namespace):  # for advanced usage only
        logger.info(f"updating context variables using the namespace: {s}")
        _ctx._set_vars(s)
        return

    logger.warning(f"Cannot convert {s} of type {type(s)} to prompt, ignore.")

101 

102 

def appl_format(
    value: Any, format_spec: str = "", conversion: int = -1
) -> StringFuture:
    """Build a ``StringFuture`` for ``value`` formatted with ``format_spec``.

    Args:
        value: The value to format.
        format_spec: Format specification handed to the builtin ``format``.
        conversion: Code point of the conversion character (``s``, ``r`` or
            ``a``) to apply before formatting; a negative value means no
            conversion.

    Returns:
        A ``StringFuture`` wrapping the deferred ``format`` call.

    Raises:
        ValueError: If ``conversion`` is non-negative but is not one of the
            supported conversion characters.
    """
    if conversion >= 0:
        converters: Dict[str, Callable] = {"s": str, "r": repr, "a": ascii}
        flag = chr(conversion)
        converter = converters.get(flag)
        if converter is None:
            raise ValueError(f"Invalid conversion character: {flag}")
        # Apply the conversion lazily, before the format call below.
        value = StringFuture(CallFuture(converter, value))

    deferred = CallFuture(format, value, format_spec)
    return StringFuture(deferred)