Coverage for src/appl/core/context.py: 90%
78 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1from __future__ import annotations
3from argparse import Namespace
4from copy import deepcopy
5from typing import Any, Optional
7from .message import BaseMessage, Conversation
8from .printer import PrinterPop, PrinterPush, PromptPrinter, PromptRecords
9from .types import Image, String, StringFuture
class PromptContext:
    """The context of the APPL function.

    Attribute access on a context is forwarded to two ``argparse.Namespace``
    objects: names starting with ``_`` live in ``locals`` (private to this
    context), all other names live in ``globals`` (which :meth:`inherit`
    shares between contexts and :meth:`copy` deep-copies).
    """

    def __init__(self, globals_: Optional[Namespace] = None):
        """Initialize the PromptContext object.

        Args:
            globals_: The global namespace of the APPL function. When omitted,
                a fresh namespace is created for this context.
        """
        if globals_ is None:
            # Create the namespace here (not as a default argument) so that
            # separate contexts never share one by accident.
            globals_ = Namespace()
        self.globals = globals_
        # Set defaults only for entries the namespace does not already hold,
        # so globals passed in by `copy`/`inherit` keep their current values.
        if "messages" not in globals_:
            self.messages = Conversation(system_messages=[], messages=[])
        if "printer" not in globals_:
            self.printer = PromptPrinter()
        if "is_outmost" not in globals_:
            self.is_outmost = True

        # Local vars start with "_" (routed to `self.locals` by __setattr__).
        self.locals = Namespace()
        self._records = PromptRecords()
        self._func_name: Optional[str] = None
        self._func_docstring: Optional[str] = None
        self._include_docstring: bool = False
        self._docstring_quote_count: Optional[int] = None
        self._is_first_str: bool = True

    @property
    def records(self) -> PromptRecords:
        """The prompt records of the context."""
        return self._records

    def set_records(self, records: PromptRecords) -> None:
        """Set the prompt records of the context."""
        self._records = records

    def add_string(self, string: String) -> None:
        """Add a string to the prompt context.

        A plain ``str`` is wrapped in a ``StringFuture`` first. The value is
        rendered through the printer into ``messages`` and appended to the
        records.
        """
        if isinstance(string, str):
            string = StringFuture(string)
        self.messages.extend(self.printer(string))
        self.records.record(string)

    def add_image(self, img: Image) -> None:
        """Add an image to the prompt context (messages and records)."""
        self.messages.extend(self.printer(img))
        self.records.record(img)

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the prompt context (messages and records)."""
        self.messages.extend(self.printer(message))
        self.records.record(message)

    def add_records(self, records: PromptRecords, write_to_prompt: bool = True) -> None:
        """Add prompt records to the prompt context.

        Args:
            records: The records to append to this context's records.
            write_to_prompt: When True, also render ``records`` through the
                printer into ``messages``; when False, only record them.
        """
        if write_to_prompt:
            self.messages.extend(self.printer(records))
        self.records.extend(records)

    def push_printer(self, push_args: PrinterPush) -> None:
        """Push a new printer state and record the push."""
        self.printer.push(push_args)
        self.records.record(push_args)

    def pop_printer(self) -> None:
        """Pop a printer state and record the pop."""
        self.printer.pop()
        self.records.record(PrinterPop())

    def copy(self) -> "PromptContext":
        """Create a new context whose globals are a deep copy of this one's.

        The new context gets fresh locals (created by ``__init__``).
        """
        return PromptContext(globals_=deepcopy(self.globals))

    def inherit(self) -> "PromptContext":
        """Create a new context sharing this one's globals object.

        Changes to globals through either context are visible to both; the
        new context gets fresh locals.
        """
        return PromptContext(globals_=self.globals)

    def _set_vars(self, vars: Namespace) -> None:
        """Assign every entry of ``vars`` as an attribute of this context.

        Each name is routed by ``__setattr__``: ``_``-prefixed names go to
        locals, all others to globals.

        Args:
            vars: An ``argparse.Namespace`` (which has no ``items()`` method,
                so its ``__dict__`` is used), or a mapping such as a ``dict``.
        """
        if isinstance(vars, Namespace):
            items = list(vars.__dict__.items())
        else:
            items = list(vars.items())
        # Iterate a snapshot: `vars` may be one of this context's own
        # namespaces, which the assignments below write into.
        for key, value in items:
            setattr(self, key, value)

    def __getattr__(self, name: str) -> Any:
        """Forward attribute access to locals and globals.

        Python only calls this after normal lookup fails. Locals take
        priority over globals.

        Raises:
            AttributeError: If ``name`` is in neither namespace. This is also
                raised, rather than recursing forever, when the instance has
                no ``locals``/``globals`` yet, e.g. an object created via
                ``__new__`` by ``copy.deepcopy`` or ``pickle``.
        """
        # logger.debug("getattr", name)
        if name in ("locals", "globals"):
            # Normal lookup already failed, so the namespace was never set.
            # (Returning `self.<name>` here would re-enter __getattr__.)
            raise AttributeError(f"Attribute '{name}' not found.")
        # Read the namespaces straight from __dict__ so that a missing entry
        # cannot trigger another __getattr__ call.
        state = self.__dict__
        local_ns = state.get("locals")
        global_ns = state.get("globals")
        # Locals have higher priority
        if local_ns is not None and name in local_ns:
            return getattr(local_ns, name)
        if global_ns is not None and name in global_ns:
            return getattr(global_ns, name)
        # Not found; hint when the caller forgot the "_" prefix of a local.
        if local_ns is not None and "_" + name in local_ns:
            raise AttributeError(
                f"Attribute '{name}' is local to the function, add '_' to access it."
            )
        raise AttributeError(f"Attribute '{name}' not found.")

    def __setattr__(self, name: str, val: Any) -> None:
        """Route attribute assignment.

        ``locals`` and ``globals`` are stored on the instance itself; names
        starting with ``_`` go to the locals namespace; everything else goes
        to the globals namespace.
        """
        # logger.debug("setattr", name, val)
        if name == "locals":
            self.__dict__["locals"] = val
        elif name == "globals":
            self.__dict__["globals"] = val
        elif name.startswith("_"):
            setattr(self.locals, name, val)
        else:
            setattr(self.globals, name, val)

    def __repr__(self) -> str:
        return f"PromptContext(globals={self.globals!r}, locals={self.locals!r})"