Coverage for src/appl/core/printer.py: 93%

210 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

from __future__ import annotations

import copy
from dataclasses import dataclass, field
from typing import List, Optional, Union

import roman
from loguru import logger

from .message import BaseMessage, Conversation, as_message
from .types import Image, MessageRole, String, StringFuture

12 

13 

class Indexing:
    """The indexing method for the printer.

    An ``Indexing`` turns a zero-based position into a list marker such as
    ``"1. "``, ``"c. "``, ``"iv. "``, ``"- "`` or ``"## "``. It keeps an
    internal counter so that successive calls to :meth:`get_index` without an
    explicit position produce consecutive markers.

    Supported ``method`` values:

    - ``None``: no marker at all (empty string, prefix/suffix not applied).
    - ``"number"``: ``1``, ``2``, ``3``, ...
    - ``"lower"`` / ``"letter"``: ``a``, ``b``, ... (at most 26 items).
    - ``"upper"`` / ``"Letter"``: ``A``, ``B``, ... (at most 26 items).
    - ``"roman"`` / ``"Roman"``: lower-/upper-case roman numerals.
    - ``"star"`` / ``"dash"``: ``*`` / ``-`` bullets.
    - ``"sharpN"``: ``N`` hash characters, e.g. ``"sharp2"`` -> ``##``.
    - any other string: used verbatim as the marker.
    """

    def __init__(
        self,
        method: Optional[str] = None,
        ind: int = 0,
        prefix: str = "",
        suffix: Optional[str] = None,
    ):
        """Initialize the indexing method.

        Args:
            method: The marker style (see the class docstring), or ``None``
                for no marker.
            ind: The zero-based position used by the next argument-less
                :meth:`get_index` call.
            prefix: Text placed before the marker.
            suffix: Text placed after the marker. ``None`` selects the default:
                ``". "`` for enumerated styles (number/letter/roman) and
                ``" "`` for all other styles. An empty string means
                "no suffix".
        """
        self._method = method
        self._ind = ind
        self._prefix = prefix
        self._suffix = suffix

    def _get_index(self, ind: int) -> str:
        """Format the marker for zero-based position ``ind``.

        Does not touch the internal counter.

        Raises:
            ValueError: For letter-based styles when ``ind >= 26``, or for a
                ``"sharpN"`` style whose ``N`` is not an integer.
        """
        if self._method is None:
            return ""
        # Enumerated markers ("1.", "a.", "iv.") default to ". ";
        # bullet-like markers ("*", "-", "##", custom) default to " ".
        default_suffix = ". "
        if self._method == "number":
            base = str(ind + 1)
        elif self._method in ("lower", "upper", "letter", "Letter"):
            if ind >= 26:
                raise ValueError("Letter-based indexing method only supports 26 items.")
            base = chr(ord("A") + ind)
            if self._method in ("lower", "letter"):
                base = base.lower()
        elif self._method in ("roman", "Roman"):
            base = roman.toRoman(ind + 1)
            if self._method == "roman":
                base = base.lower()
        else:
            default_suffix = " "
            if self._method == "star":
                base = "*"
            elif self._method == "dash":
                base = "-"
            elif self._method.startswith("sharp"):
                # "sharpN" -> N '#' characters, e.g. "sharp2" -> "##".
                base = "#" * int(self._method[5:])
            else:
                # Any other string is used verbatim as the marker.
                base = self._method

        # Only ``None`` means "use the default"; an explicit "" must be honored
        # (``self._suffix or default_suffix`` would silently replace it).
        suffix = default_suffix if self._suffix is None else self._suffix
        return self._prefix + base + suffix

    def get_index(self, ind: Optional[int] = None) -> str:
        """Get the index string for the current or given index.

        Args:
            ind: Explicit zero-based position. When omitted, the internal
                counter is used and then advanced by one.

        Raises:
            ValueError: If the position is negative.
        """
        if ind is None:
            ind = self._ind
            self._ind += 1
        if ind < 0:
            raise ValueError("Indexing method does not support negative indexing.")
        return self._get_index(ind)

    def __repr__(self) -> str:
        # Include every constructor argument so the repr fully describes the
        # object.
        return (
            f"Indexing(method={self._method!r}, ind={self._ind!r}, "
            f"prefix={self._prefix!r}, suffix={self._suffix!r})"
        )

70 

71 

@dataclass
class PrinterState:
    """A state of the printer (one entry of the printer's state stack)."""

    # settings
    role: Optional[MessageRole] = None
    """The role to be used for the message."""
    separator: str = "\n"
    """The separator to be used between texts."""
    # ``Indexing`` is mutable (``get_index()`` advances its counter), so every
    # state needs its own instance. A plain ``= Indexing(None, 0)`` default
    # would be one object shared by all default-constructed states.
    indexing: Indexing = field(default_factory=lambda: Indexing(None, 0))
    """The indexing method to be used."""
    indent: str = ""
    """The indent to be used in the beginning of each line."""
    # inline means the first indent and indexing is inherited from the previous state
    is_inline: bool = False
    """Whether the state is inline. Inline means the first indent and
    indexing is inherited from the previous non-inline state."""
    # states
    is_start: bool = True
    """Whether the state is at the start of the scope."""
    current_sep: str = ""
    """The current separator to be used between texts."""

94 

95 

@dataclass
class PrinterPush:
    """Record asking the printer to open a new scope (push a state).

    Fields left as ``None`` let the printer choose a default when the record
    is processed (see ``PromptPrinter._push``).
    """

    new_role: Optional[MessageRole] = None
    """Role for the new scope; ``None`` keeps the current role."""
    separator: Optional[str] = None
    """Text inserted between consecutive prints; ``None`` picks the default."""
    indexing: Optional[Indexing] = None
    """Marker generator for printed items; ``None`` picks the default."""
    inc_indent: str = ""
    """Extra indentation appended to the current indent."""
    new_indent: Optional[str] = None
    """Indent that replaces the current one instead of extending it."""
    is_inline: Optional[bool] = False
    """Whether the new scope is inline."""

112 

113 

@dataclass
class PrinterPop:
    """Record asking the printer to close the innermost scope (pop a state).

    It carries no data: its position in the record list is all that matters.
    """

117 

118 

RecordType = Union[
    BaseMessage,
    StringFuture,
    Image,
    PrinterPush,
    PrinterPop,
]
"""Types allowed in the prompt records: messages, (future) strings, images,
and the printer push/pop control records."""

121 

122 

class PromptRecords:
    """A class represents a list of prompt records.

    Records are strings (stored as ``StringFuture``), images, messages, and
    ``PrinterPush``/``PrinterPop`` control records. They are rendered into a
    ``Conversation`` by ``PromptPrinter``.
    """

    def __init__(self) -> None:
        """Initialize an empty list of prompt records."""
        self._records: List[RecordType] = []

    @property
    def records(self) -> List[RecordType]:
        """The list of records (the internal list itself, not a copy)."""
        return self._records

    def as_convo(self) -> Conversation:
        """Convert the prompt records to a conversation.

        A fresh ``PromptPrinter`` with its default state is used for the
        conversion.
        """
        return PromptPrinter()(self)

    def record(self, record: Union[str, RecordType]) -> None:
        """Record a string, message, image, printer push or printer pop.

        Plain ``str`` values are wrapped in ``StringFuture`` before storing.

        Raises:
            ValueError: If ``record`` is none of the supported types.
        """
        if isinstance(record, str):  # compatible to str
            record = StringFuture(record)
        if not isinstance(
            record, (StringFuture, Image, BaseMessage, PrinterPush, PrinterPop)
        ):
            # ValueError (not TypeError) is kept so existing callers that
            # catch it keep working.
            raise ValueError(
                "Can only record str, StringFuture, Image, BaseMessage, "
                f"PrinterPush or PrinterPop, got {type(record).__name__}"
            )
        self._records.append(record)

    def extend(self, record: "PromptRecords") -> None:
        """Append all records of another ``PromptRecords`` (shallow, no copy)."""
        self._records.extend(record._records)

    def copy(self) -> "PromptRecords":
        """Return a deep copy of the prompt records."""
        return copy.deepcopy(self)

    def __str__(self) -> str:
        return str(self.as_convo())

164 

165 

class PromptPrinter:
    """A class to print prompt records as conversation.

    The printer maintains a stack of printer states about the
    current role, separator, indexing, and indentation. ``PrinterPush`` /
    ``PrinterPop`` records open and close nested scopes. The outermost
    (first) state can never be popped. ``self._is_newline`` tracks whether
    the next printed text starts at the beginning of a line, which decides
    whether the indent is emitted.
    """

    def __init__(
        self, states: Optional[List[PrinterState]] = None, is_newline: bool = True
    ) -> None:
        """Initialize the prompt printer.

        Args:
            states: Initial state stack, outermost first. Defaults to a single
                default ``PrinterState``. The given list is used as-is (not
                copied) and is mutated by push/pop.
            is_newline: Whether printing starts at the beginning of a line,
                i.e. whether the first print may emit the indent.
        """
        if states is None:
            states = [PrinterState()]
        self._states = states
        self._is_newline = is_newline

    @property
    def states(self) -> List[PrinterState]:
        """The stack of printer states."""
        return self._states

    def push(self, data: PrinterPush) -> None:
        """Push a new printer state to the stack.

        The fields of ``data`` are forwarded as keyword arguments to ``_push``.
        """
        self._push(**data.__dict__)

    def pop(self) -> None:
        """Pop the last printer state from the stack.

        Raises:
            ValueError: If only the outermost state is left.
        """
        if len(self._states) == 1:
            raise ValueError("Cannot pop the first state.")
        self._states.pop()

    def _push(
        self,
        new_role: Optional[MessageRole] = None,
        separator: Optional[str] = None,
        indexing: Optional[Indexing] = None,
        inc_indent: str = "",
        new_indent: Optional[str] = None,
        is_inline: bool = False,
    ) -> None:
        """Append a new state derived from the current top state.

        If ``new_role`` is ``None`` or equals the current role, the new scope
        continues the current one. Separator and indexing default to the
        parent's, and the parent's current separator is carried over.
        Otherwise a new role starts. This is only allowed when the stack holds
        just the outermost state. That state is reset, and the defaults become
        a newline separator, an empty ``Indexing`` and an empty indent.

        A given ``indexing`` is shallow-copied, so the caller's object (e.g.
        one stored in a record) is not advanced by printing.

        Raises:
            ValueError: If a new role is started while more than one state is
                on the stack. Also raised if ``inc_indent`` is given when a new
                role starts without ``new_indent``, or if both ``inc_indent``
                and ``new_indent`` are given.
        """
        state = self.states[-1]
        if new_role is None or new_role == state.role:
            new_role = state.role
            current_separator = state.current_sep
            default_separator = state.separator  # Use the same separator as parent
            default_indexing = state.indexing  # Use the same indexing as parent
        else:  # a new role started
            logger.debug(f"new role started {new_role}")
            if len(self.states) > 1:
                raise ValueError(
                    "Cannot start a new role when there are states in the stack."
                )
            state.is_start = True  # reset the outmost state
            state.current_sep = ""  # also reset the current separator
            current_separator = ""
            default_separator = "\n"
            default_indexing = Indexing(None, 0)  # create an empty indexing
            if new_indent is None:
                new_indent = ""  # reset the indent
                if inc_indent:
                    raise ValueError(
                        "Cannot specify inc_indent when new role started. "
                        "Use new_indent instead."
                    )

        if separator is None:
            separator = default_separator
        if indexing is None:
            indexing = default_indexing
        else:
            # Avoid changing the original indexing (could be a record)
            indexing = copy.copy(indexing)
        if new_indent is None:
            new_indent = state.indent + inc_indent  # increment the indent
        elif inc_indent:
            raise ValueError("Cannot specify both inc_indent and new_indent.")

        self._states.append(
            PrinterState(
                new_role,
                separator,
                indexing,
                new_indent,
                is_inline,
                True,
                # The current separator is inherited from the parent state
                # it will change to its own separator after the first print.
                current_sep=current_separator,
            )
        )

    def _print_str(self, content: String) -> StringFuture:
        """Render ``content`` with the separator, indent and index of the top state.

        On the first print in a scope (``is_start``), the scope's current
        separator switches to its own ``separator``. Enclosing scopes of the
        same role that have not printed yet are marked as started the same
        way, walking outward until one that already printed. For an inline
        scope, that first print uses the indent and indexing of the nearest
        enclosing non-inline scope of the same role.

        The indent is emitted only when the printer is at the start of a line.
        ``indexing.get_index()`` is called on every print, which advances that
        indexing's counter.
        """
        state, previous = self._states[-1], self._states[:-1]
        role = state.role
        sep = state.current_sep
        indent = state.indent
        indexing = state.indexing
        if state.is_start:  # is the first content in this scope
            state.is_start = False
            state.current_sep = state.separator
            for st in previous[::-1]:
                if st.role != role:
                    break
                if st.is_start:
                    # after first print, change the separator to its own
                    st.is_start = False
                    st.current_sep = st.separator
                else:
                    break

            if state.is_inline:
                # inline means the first indent and indexing is
                # inherited from the previous non-inline state
                for st in previous[::-1]:
                    if st.role != role:
                        break
                    if not st.is_inline:
                        # Use the first non-inline's indent and indexing
                        indent = st.indent
                        indexing = st.indexing
                        break
        # A separator ending in a newline puts the text at a line start.
        if sep.endswith("\n"):
            self._is_newline = True

        s = StringFuture(sep)
        if self._is_newline:
            if indent:
                s += indent
            self._is_newline = False
        if cur_idx := indexing.get_index():
            s += cur_idx
        s += content

        # TODO: maybe check whether `s` ends with newline
        return s

    def _print_message(self, content: String) -> BaseMessage:
        """Print a string as message with the current printer state."""
        role = self._states[-1].role  # the default role within the context
        content = self._print_str(content)
        return as_message(role, content)

    def _print(
        self, contents: Union[String, Image, BaseMessage, PromptRecords]
    ) -> Conversation:
        """Print ``contents`` into a new ``Conversation``.

        ``contents`` is a single item, or a ``PromptRecords`` whose records
        are handled in order. Strings are rendered through the state stack.
        Images are wrapped with the current role. Messages are appended
        unchanged. Push/pop records modify the state stack.

        Raises:
            ValueError: On an unsupported record type, or when a push/pop is
                invalid (see ``_push`` / ``pop``).
        """
        convo = Conversation(system_messages=[], messages=[])

        def handle(rec: Union[RecordType, Image, str]) -> None:
            if isinstance(rec, (str, StringFuture)):
                convo.append(self._print_message(rec))
            elif isinstance(rec, Image):
                convo.append(as_message(self._states[-1].role, rec))
                state = self._states[-1]
                # reset current state after image, TODO: double check
                state.is_start = True
                state.current_sep = ""
            elif isinstance(rec, BaseMessage):
                convo.append(rec)
                if rec.role is not None and len(self._states) == 1:
                    # change role in the outmost state
                    # NOTE(review): only is_start/current_sep are reset here;
                    # state.role itself is not updated, so later strings still
                    # use the state's existing role — confirm this is intended.
                    state = self._states[0]
                    state.is_start = True  # reset the outmost state
                    state.current_sep = ""  # also reset the current separator
                    # TODO: what should be the behavior if the role is changed
                    # in states other than the outmost one?
                    # should such behavior being allowed?

            elif isinstance(rec, PrinterPush):
                self.push(rec)
            elif isinstance(rec, PrinterPop):
                self.pop()
            else:
                raise ValueError(f"Unknown record type {type(rec)}")

        if isinstance(contents, PromptRecords):
            for rec in contents.records:
                handle(rec)
        else:
            handle(contents)

        return convo

    # Calling the printer instance is the same as ``_print``.
    __call__ = _print