Coverage for src/appl/core/printer.py: 93%
210 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
from __future__ import annotations

import copy
from dataclasses import dataclass, field
from typing import List, Optional, Union

import roman
from loguru import logger

from .message import BaseMessage, Conversation, as_message
from .types import Image, MessageRole, String, StringFuture
class Indexing:
    """The indexing method for the printer.

    The ``method`` string selects how each item is labelled:

    * ``None``: no index at all (an empty string is produced).
    * ``"number"``: ``1``, ``2``, ``3``, ...
    * ``"lower"`` / ``"letter"``: ``a`` .. ``z`` (at most 26 items).
    * ``"upper"`` / ``"Letter"``: ``A`` .. ``Z`` (at most 26 items).
    * ``"roman"`` / ``"Roman"``: lower / upper case roman numerals.
    * ``"star"``: ``*``; ``"dash"``: ``-``.
    * ``"sharp<N>"``: ``#`` repeated ``N`` times (e.g. ``"sharp2"`` -> ``##``).
    * anything else: the method string itself is used as the bullet.

    Number, letter and roman indices default to the suffix ``". "``; every
    other method defaults to ``" "``.
    """

    def __init__(
        self,
        method: Optional[str] = None,
        ind: int = 0,
        prefix: str = "",
        suffix: Optional[str] = None,
    ):
        """Initialize the indexing method.

        Args:
            method: The name of the indexing method (see the class docstring).
            ind: The zero-based index that the next automatic call to
                ``get_index`` will use.
            prefix: Text placed before every index.
            suffix: Text placed after every index. ``None`` (or an empty
                string) selects the default suffix of the method.
        """
        self._method = method
        self._ind = ind
        self._prefix = prefix
        self._suffix = suffix

    def _get_index(self, ind: int) -> str:
        """Format the zero-based ``ind`` according to the indexing method."""
        if self._method is None:
            return ""
        default_suffix = ". "
        if self._method == "number":
            base = str(ind + 1)
        elif self._method in ["lower", "upper", "letter", "Letter"]:
            if ind >= 26:
                raise ValueError("Letter-based indexing method only supports 26 items.")
            base = chr(ord("A") + ind)
            if self._method in ["lower", "letter"]:
                base = base.lower()
        elif self._method in ["roman", "Roman"]:
            base = roman.toRoman(ind + 1)
            if self._method == "roman":
                base = base.lower()
        else:
            # Bullet-like methods do not depend on the position.
            default_suffix = " "
            if self._method == "star":
                base = "*"
            elif self._method == "dash":
                base = "-"
            elif self._method.startswith("sharp"):
                try:
                    count = int(self._method[5:])
                except ValueError as err:
                    # Keep the exception type, but say what was wrong instead
                    # of leaking the bare ``int()`` parsing message.
                    raise ValueError(
                        f"Invalid indexing method {self._method!r}: expected "
                        "'sharp<N>' with an integer N, e.g. 'sharp2'."
                    ) from err
                base = "#" * count
            else:
                base = self._method

        # NOTE: an empty-string suffix is falsy and falls back to the default.
        return self._prefix + base + (self._suffix or default_suffix)

    def get_index(self, ind: Optional[int] = None) -> str:
        """Get the index string for the current or given index.

        Args:
            ind: The zero-based index to format. When ``None``, the internal
                counter is used and then advanced by one.

        Returns:
            The formatted index, including prefix and suffix.

        Raises:
            ValueError: If the index is negative, or the method cannot
                represent it.
        """
        if ind is None:
            ind = self._ind
            self._ind += 1
        if ind < 0:
            raise ValueError("Indexing method does not support negative indexing.")
        return self._get_index(ind)

    def __repr__(self) -> str:
        return (
            f"Indexing(method={self._method!r}, ind={self._ind!r}, "
            f"prefix={self._prefix!r}, suffix={self._suffix!r})"
        )
@dataclass
class PrinterState:
    """A state of the printer."""

    # settings
    role: Optional[MessageRole] = None
    """The role to be used for the message."""
    separator: str = "\n"
    """The separator to be used between texts."""
    # A factory is used so that every state gets its own ``Indexing`` object;
    # ``Indexing`` carries a mutable counter and must not be shared through a
    # class-level default instance.
    indexing: Indexing = field(default_factory=lambda: Indexing(None, 0))
    """The indexing method to be used."""
    indent: str = ""
    """The indent to be used in the beginning of each line."""
    # inline means the first indent and indexing is inherited from the previous state
    is_inline: bool = False
    """Whether the state is inline. Inline means the first indent and
    indexing is inherited from the previous non-inline state."""
    # states
    is_start: bool = True
    """Whether the state is at the start of the scope."""
    current_sep: str = ""
    """The current separator to be used between texts."""
@dataclass
class PrinterPush:
    """A request to open a new printer scope (push a state onto the stack).

    Fields left as ``None`` are resolved by the printer when the record is
    handled.
    """

    new_role: Optional[MessageRole] = None
    """Role for messages in the new scope."""
    separator: Optional[str] = None
    """Separator placed between texts in the new scope."""
    indexing: Optional[Indexing] = None
    """Indexing method for the new scope."""
    inc_indent: str = ""
    """Text appended to the current indent."""
    new_indent: Optional[str] = None
    """Indent that replaces the current one."""
    is_inline: Optional[bool] = False
    """Whether the new scope is inline."""
@dataclass
class PrinterPop:
    """A marker record asking the printer to pop its most recent state."""
RecordType = Union[
    BaseMessage,
    StringFuture,
    Image,
    PrinterPush,
    PrinterPop,
]
"""The kinds of items that may be stored as prompt records."""
class PromptRecords:
    """A class represents a list of prompt records."""

    def __init__(self) -> None:
        """Initialize the prompt records."""
        self._records: List[RecordType] = []

    @property
    def records(self) -> List[RecordType]:
        """The list of records."""
        return self._records

    def as_convo(self) -> Conversation:
        """Convert the prompt records to a conversation."""
        return PromptPrinter()(self)

    def record(self, record: Union[str, RecordType]) -> None:
        """Record a string, message, image, printer push or printer pop.

        Args:
            record: The item to append. A plain ``str`` is wrapped in a
                ``StringFuture`` before being stored.

        Raises:
            ValueError: If ``record`` is not one of the supported types.
        """
        if isinstance(record, str):  # compatible to str
            record = StringFuture(record)
        if not isinstance(
            record, (StringFuture, Image, BaseMessage, PrinterPush, PrinterPop)
        ):
            # Name every accepted type and report what was actually given.
            raise ValueError(
                "Can only record str, StringFuture, Image, Message, "
                f"PrinterPush or PrinterPop, got {type(record).__name__}"
            )
        self._records.append(record)

    def extend(self, record: "PromptRecords") -> None:
        """Extend the prompt records with another prompt records."""
        self._records.extend(record._records)

    def copy(self) -> "PromptRecords":
        """Copy the prompt records (a deep copy)."""
        return copy.deepcopy(self)

    def __str__(self) -> str:
        return str(self.as_convo())
class PromptPrinter:
    """Render prompt records into a conversation.

    The printer keeps a stack of printer states; the top of the stack
    decides the role, separator, indexing and indentation applied to the
    next piece of content.
    """

    def __init__(
        self, states: Optional[List[PrinterState]] = None, is_newline: bool = True
    ) -> None:
        """Create a printer, starting from a single default state if none is given."""
        self._states = [PrinterState()] if states is None else states
        self._is_newline = is_newline

    @property
    def states(self) -> List[PrinterState]:
        """The stack of printer states."""
        return self._states

    def push(self, data: PrinterPush) -> None:
        """Open a new scope described by ``data``."""
        self._push(**vars(data))

    def pop(self) -> None:
        """Close the innermost scope; the outermost state cannot be removed."""
        if len(self._states) == 1:
            raise ValueError("Cannot pop the first state.")
        self._states.pop()

    def _push(
        self,
        new_role: Optional[MessageRole] = None,
        separator: Optional[str] = None,
        indexing: Optional[Indexing] = None,
        inc_indent: str = "",
        new_indent: Optional[str] = None,
        is_inline: bool = False,
    ) -> None:
        parent = self.states[-1]
        keeps_role = new_role is None or new_role == parent.role
        if keeps_role:
            # Same role: separator and indexing default to the parent's.
            new_role = parent.role
            inherited_sep = parent.current_sep
            fallback_sep = parent.separator
            fallback_indexing = parent.indexing
        else:
            # A different role begins; only allowed from the outermost state.
            logger.debug(f"new role started {new_role}")
            if len(self.states) > 1:
                raise ValueError(
                    "Cannot start a new role when there are states in the stack."
                )
            # The outermost state starts over for the new role.
            parent.is_start = True
            parent.current_sep = ""
            inherited_sep = ""
            fallback_sep = "\n"
            fallback_indexing = Indexing(None, 0)
            if new_indent is None:
                new_indent = ""  # the indent starts over as well
            if inc_indent:
                raise ValueError(
                    "Cannot specify inc_indent when new role started. "
                    "Use new_indent instead."
                )

        if separator is None:
            separator = fallback_sep
        # A caller-supplied indexing may belong to a record, so work on a copy.
        indexing = fallback_indexing if indexing is None else copy.copy(indexing)
        if new_indent is None:
            new_indent = parent.indent + inc_indent
        elif inc_indent:
            raise ValueError("Cannot specify both inc_indent and new_indent.")

        self._states.append(
            PrinterState(
                role=new_role,
                separator=separator,
                indexing=indexing,
                indent=new_indent,
                is_inline=is_inline,
                is_start=True,
                # Starts with the parent's separator; switches to its own
                # separator after the first print.
                current_sep=inherited_sep,
            )
        )

    def _print_str(self, content: String) -> StringFuture:
        top = self._states[-1]
        ancestors = self._states[:-1]
        role = top.role
        sep = top.current_sep
        indent, indexing = top.indent, top.indexing

        if top.is_start:
            # First content of this scope: from now on use its own separator,
            # and do the same for enclosing scopes of the same role that have
            # not printed anything yet.
            top.is_start = False
            top.current_sep = top.separator
            for outer in reversed(ancestors):
                if outer.role != role or not outer.is_start:
                    break
                outer.is_start = False
                outer.current_sep = outer.separator

        if top.is_inline:
            # An inline scope borrows indent and indexing from the nearest
            # non-inline enclosing scope of the same role.
            for outer in reversed(ancestors):
                if outer.role != role:
                    break
                if not outer.is_inline:
                    indent, indexing = outer.indent, outer.indexing
                    break

        if sep.endswith("\n"):
            self._is_newline = True

        text = StringFuture(sep)
        if self._is_newline:
            if indent:
                text += indent
            self._is_newline = False
        index_text = indexing.get_index()
        if index_text:
            text += index_text
        text += content

        # TODO: maybe check whether `text` ends with newline
        return text

    def _print_message(self, content: String) -> BaseMessage:
        """Wrap ``content`` as a message using the role of the current state."""
        role = self._states[-1].role
        return as_message(role, self._print_str(content))

    def _print(
        self, contents: Union[String, Image, BaseMessage, PromptRecords]
    ) -> Conversation:
        convo = Conversation(system_messages=[], messages=[])

        def handle(rec: Union[RecordType, Image, str]) -> None:
            if isinstance(rec, (str, StringFuture)):
                convo.append(self._print_message(rec))
                return
            if isinstance(rec, Image):
                convo.append(as_message(self._states[-1].role, rec))
                current = self._states[-1]
                # reset current state after image, TODO: double check
                current.is_start = True
                current.current_sep = ""
                return
            if isinstance(rec, BaseMessage):
                convo.append(rec)
                if rec.role is not None and len(self._states) == 1:
                    # a message with a role restarts the outermost state
                    outermost = self._states[0]
                    outermost.is_start = True
                    outermost.current_sep = ""
                # TODO: what should be the behavior if the role is changed
                # in states other than the outmost one?
                # should such behavior being allowed?
                return
            if isinstance(rec, PrinterPush):
                self.push(rec)
                return
            if isinstance(rec, PrinterPop):
                self.pop()
                return
            raise ValueError(f"Unknown record type {type(rec)}")

        if isinstance(contents, PromptRecords):
            for rec in contents.records:
                handle(rec)
        else:
            handle(contents)

        return convo

    __call__ = _print