Coverage for src/appl/core/context.py: 90%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1from __future__ import annotations 

2 

3from argparse import Namespace 

4from copy import deepcopy 

5from typing import Any, Optional 

6 

7from .message import BaseMessage, Conversation 

8from .printer import PrinterPop, PrinterPush, PromptPrinter, PromptRecords 

9from .types import Image, String, StringFuture 

10 

11 

class PromptContext:
    """The context of the APPL function.

    Attribute access and assignment on a context are routed to two
    ``argparse.Namespace`` objects held in the instance ``__dict__``:

    * ``globals`` -- every name that does not start with ``"_"``. This object
      may be shared with other contexts (see :meth:`inherit`) or deep-copied
      (see :meth:`copy`).
    * ``locals`` -- every name starting with ``"_"``. A fresh namespace is
      created for each context in ``__init__``.
    """

    def __init__(self, globals_: Optional[Namespace] = None):
        """Initialize the PromptContext object.

        Args:
            globals_: The global namespace of the APPL function. When ``None``,
                a fresh ``Namespace`` is created. The entries ``messages``,
                ``printer`` and ``is_outmost`` are filled with defaults only if
                they are not already present; existing entries are kept.
        """
        if globals_ is None:
            # Create the namespace here (not as a default argument value) so
            # that each context gets its own instance.
            globals_ = Namespace()
        self.globals = globals_
        # Set default values; ``__setattr__`` routes these names to globals.
        if "messages" not in globals_:
            self.messages = Conversation(system_messages=[], messages=[])
        if "printer" not in globals_:
            self.printer = PromptPrinter()
        if "is_outmost" not in globals_:
            # NOTE(review): presumably marks the outermost APPL call; confirm.
            self.is_outmost = True

        # Local vars start with "_" and are stored in ``self.locals``.
        self.locals = Namespace()
        self._records = PromptRecords()
        self._func_name: Optional[str] = None
        self._func_docstring: Optional[str] = None
        self._include_docstring: bool = False
        self._docstring_quote_count: Optional[int] = None
        self._is_first_str: bool = True

    @property
    def records(self) -> PromptRecords:
        """The prompt records of the context (stored in locals as ``_records``)."""
        return self._records

    def set_records(self, records: PromptRecords) -> None:
        """Replace the prompt records of the context."""
        self._records = records

    def add_string(self, string: String) -> None:
        """Add a string to the prompt context.

        A plain ``str`` is wrapped in a ``StringFuture`` first. The printer's
        output for it is appended to ``messages``, and the string is recorded.
        """
        if isinstance(string, str):
            string = StringFuture(string)
        self.messages.extend(self.printer(string))
        self.records.record(string)

    def add_image(self, img: Image) -> None:
        """Add an image to the prompt context and record it."""
        self.messages.extend(self.printer(img))
        self.records.record(img)

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the prompt context and record it."""
        self.messages.extend(self.printer(message))
        self.records.record(message)

    def add_records(self, records: PromptRecords, write_to_prompt: bool = True) -> None:
        """Add prompt records to the prompt context.

        Args:
            records: The records to append to this context's records.
            write_to_prompt: When True, the records are also passed through the
                printer and the result is appended to ``messages``.
        """
        if write_to_prompt:
            self.messages.extend(self.printer(records))
        self.records.extend(records)

    def push_printer(self, push_args: PrinterPush) -> None:
        """Push a new printer state and record the push."""
        self.printer.push(push_args)
        self.records.record(push_args)

    def pop_printer(self) -> None:
        """Pop a printer state and record a ``PrinterPop`` marker."""
        self.printer.pop()
        self.records.record(PrinterPop())

    def copy(self) -> "PromptContext":
        """Create a new context whose globals are a deep copy of this one's.

        The new context gets fresh locals from ``__init__``.
        """
        return PromptContext(globals_=deepcopy(self.globals))

    def inherit(self) -> "PromptContext":
        """Create a new context that shares this context's globals object.

        The new context gets fresh locals from ``__init__``.
        """
        return PromptContext(globals_=self.globals)

    def _set_vars(self, vars: Namespace) -> None:
        """Assign every entry of ``vars`` onto this context.

        Each assignment goes through ``__setattr__``, so names starting with
        ``"_"`` land in locals and all others in globals. ``argparse.Namespace``
        has no ``items()`` method, so its ``__dict__`` is iterated; a mapping
        (e.g. a ``dict``) is also accepted.

        Args:
            vars: A ``Namespace`` or mapping of name -> value. The parameter
                name shadows the builtin ``vars``; it is kept for backward
                compatibility with keyword callers.
        """
        source = vars.__dict__ if isinstance(vars, Namespace) else vars
        # Snapshot the items in case ``vars`` is one of our own namespaces.
        for key, value in list(source.items()):
            setattr(self, key, value)

    def __getattr__(self, name: str) -> Any:
        """Forward attribute access to locals and globals (locals take priority).

        Python calls this hook only after normal lookup fails. If ``locals`` or
        ``globals`` itself reaches this point, it was never set, e.g. on an
        instance created via ``__new__`` by ``copy.deepcopy`` or ``pickle``.
        Reading it through ``self`` would re-enter this method forever. The
        namespaces are therefore read from ``__dict__`` directly, and a
        missing one is treated as empty.

        Raises:
            AttributeError: If ``name`` is in neither namespace. The message
                hints at the underscore when ``"_" + name`` is a local.
        """
        if name in ("locals", "globals"):
            raise AttributeError(f"Attribute '{name}' not found.")
        locals_ = self.__dict__.get("locals")
        globals_ = self.__dict__.get("globals")
        # Locals have higher priority
        if locals_ is not None and name in locals_:
            return getattr(locals_, name)
        if globals_ is not None and name in globals_:
            return getattr(globals_, name)
        # Not found; give a hint if the caller forgot the leading underscore.
        if locals_ is not None and "_" + name in locals_:
            raise AttributeError(
                f"Attribute '{name}' is local to the function, add '_' to access it."
            )
        raise AttributeError(f"Attribute '{name}' not found.")

    def __setattr__(self, name: str, val: Any) -> None:
        """Route attribute assignment to the appropriate namespace.

        ``locals`` and ``globals`` are stored directly in the instance
        ``__dict__``. Names starting with ``"_"`` are set on ``locals``; all
        other names are set on ``globals``.
        """
        if name in ("locals", "globals"):
            self.__dict__[name] = val
        elif name.startswith("_"):
            setattr(self.locals, name, val)
        else:
            setattr(self.globals, name, val)

    def __repr__(self) -> str:
        return f"PromptContext(globals={self.globals!r}, locals={self.locals!r})"