Coverage for src/appl/tracing/printer.py: 22%

185 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1import copy 

2import json 

3import os 

4import uuid 

5from datetime import datetime, timezone 

6from typing import Any, Dict, List, Optional, Union 

7 

8from litellm import ModelResponse 

9from loguru import logger 

10 

11from ..compositor import Tagged as OriginalTagged 

12from ..core.config import Configs 

13from ..core.io import load_file 

14from ..core.printer import PromptRecords 

15from ..core.trace import ( 

16 CompletionRequestEvent, 

17 CompletionResponseEvent, 

18 FunctionCallEvent, 

19 FunctionReturnEvent, 

20 GenerationInitEvent, 

21 GenerationResponseEvent, 

22 TraceEngineBase, 

23 TraceEventBase, 

24 TraceNode, 

25 TracePrinterBase, 

26) 

27from ..func import partial, ppl, records 

28 

# Directory holding this module; header.html is loaded relative to it.
folder = os.path.split(__file__)[0]

# Tagged pre-configured with indent_inside=4 (presumably the indentation
# applied to content nested inside each tag) for the HTML printer below.
Tagged = partial(OriginalTagged, indent_inside=4)

32 

33 

def timestamp_to_iso(time_stamp: float) -> str:
    """Render a POSIX timestamp as an ISO-8601 string in UTC.

    Args:
        time_stamp: Seconds since the Unix epoch.

    Returns:
        The timestamp formatted by ``datetime.isoformat`` with a ``+00:00`` offset.
    """
    moment = datetime.fromtimestamp(time_stamp, tz=timezone.utc)
    return moment.isoformat()

37 

38 

class TraceHTMLPrinter(TracePrinterBase):
    """The printer used to print the trace in the format of HTML.

    The page is assembled by the ``@ppl`` methods below (appl semantics):
    inside them, bare expression statements (strings, f-strings, calls that
    return records) are collected as prompt records, and
    ``with Tagged(tag, attrs=...)`` wraps whatever its body records in that
    HTML tag. The ``@ppl`` helpers are therefore documented with ``#``
    comments rather than docstrings, so no stray text can reach the page.

    All trace-derived text (node names, messages, arguments, outputs, config
    dumps) is HTML-escaped before being embedded, so content such as
    ``<tag>`` in an LLM response is displayed literally instead of being
    parsed as markup.
    """

    def __init__(self):
        """Initialize the printer."""
        super().__init__()
        # Bootstrap class names are case-sensitive: it must be the lowercase
        # "bg-success-subtle" for the background utility to apply.
        self._generation_style = "text-success-emphasis bg-success-subtle list-group-item d-flex justify-content-between align-items-center"
        self._time_style = "position-absolute top-0 start-100 translate-middle badge rounded-pill bg-info"
        self._cost_style = "position-absolute bottom-0 start-100 translate-middle badge rounded-pill bg-warning"
        # NOTE(review): not read anywhere in this class; presumably intended to
        # truncate the output shown on a generation's summary line.
        self._longest_shown_output = 70

        # Page <head> markup, loaded from header.html next to this module.
        self._head = load_file(os.path.join(folder, "header.html"))
        # Badge colour per chat role; unknown roles fall back to text-bg-info.
        self._color_map = {
            "user": "text-bg-info",
            "assistant": "text-bg-warning",
            "system": "text-bg-success",
        }

    @ppl
    def print(self, trace: TraceEngineBase, meta_data: Optional[Configs] = None) -> str:
        """Print the trace in the format of HTML."""
        with Tagged("html"):
            self._head
            with Tagged("body"):
                # Only root nodes here; children are rendered recursively.
                for node in trace.trace_nodes.values():
                    if node.parent is None:
                        self._print_node(node, trace.min_timestamp)
                if meta_data:
                    with Tagged("table", attrs={"class": "table small"}):
                        if start_time := meta_data.getattrs("info.start_time"):
                            self._make_line("Start Time", self._escape(start_time))
                        self._make_line(
                            "Full Configs",
                            f"<pre>{self._escape(meta_data.to_yaml())}</pre>",
                        )
        return str(records())

    @ppl
    def _print_messages(self, messages: List[Union[str, Dict]]) -> PromptRecords:
        # Render chat messages as a Bootstrap list group. A dict with both
        # "role" and "content" gets a coloured role badge followed by its
        # content; anything else is shown through its string form.
        def display(message: Union[str, Dict]) -> str:
            return f'<div style="white-space: pre-wrap;">{self._escape(message)}</div>'

        with Tagged("ul", attrs={"class": "list-group small"}):
            for message in messages:
                with Tagged("li", attrs={"class": "list-group-item"}):
                    if (
                        isinstance(message, dict)
                        and "role" in message
                        and "content" in message
                    ):
                        role = message["role"]
                        color = self._color_map.get(role, "text-bg-info")
                        f"<span class='badge {color}'>{self._escape(role)}</span>"
                        display(message["content"])
                    else:
                        display(message)
        return records()

    @ppl
    def _print_genargs(self, args: Dict, output: Optional[str] = None) -> PromptRecords:
        # Render generation arguments as a two-column table (name | value),
        # with an extra "output" row when ``output`` is not None.
        with Tagged("table", attrs={"class": "table small"}):
            for k, v in args.items():
                with Tagged("tr"):
                    with Tagged("th"):
                        self._escape(k)
                    with Tagged("td"):
                        if k == "messages":
                            self._print_messages(v)
                        elif k == "response_format":
                            # Display only: stringify values JSON can't encode
                            # instead of failing the whole page.
                            f"<pre>{self._escape(json.dumps(v, indent=2, default=str))}</pre>"
                        elif k == "response_model":
                            f"<pre>{self._escape(v)}</pre>"
                        elif k == "stop":
                            # repr keeps escape sequences (e.g. "\n") visible.
                            self._escape(repr(v))
                        else:
                            self._escape(v)
            if output is not None:
                self._make_line("output", self._escape(output))
        return records()

    @ppl
    def _print_gen(self, node: TraceNode, min_timestamp: float = 0.0) -> PromptRecords:
        # Render one generation node: a clickable summary line (name and
        # output, plus runtime/cost badges taken from its first child when
        # present) that toggles a collapsed table of its arguments. A node
        # without args is shown as "Unfinished Generation". ``min_timestamp``
        # is accepted for signature parity with _print_func and is unused.
        name = node.name
        if node.args is not None:
            completion = node.children[0] if node.children else None
            with Tagged("ul", attrs={"class": "list-group m-1"}):
                li_attrs = {"class": self._generation_style}
                li_attrs.update(self._toggle_attrs(name))
                with Tagged("li", attrs=li_attrs):
                    f"<div><b>{self._escape(name)}:</b> {self._escape(node.ret)}</div>"
                    if completion and (runtime := completion.runtime) > 0:
                        f"<span class='{self._time_style}'>{runtime:.2e} s</span>"
                    if completion and (cost := completion.info.get("cost")):
                        f"<span class='{self._cost_style}'>$ {cost:.2e}</span>"
                with Tagged(
                    "li", attrs={"class": "list-group-item collapse", "id": name}
                ):
                    self._print_genargs(node.args, node.ret)
        else:
            li_attrs = {
                "class": "list-group-item text-warning-emphasis bg-warning-subtle"
            }
            with Tagged("ul", attrs={"class": "list-group m-1"}):
                with Tagged("li", attrs=li_attrs):
                    "Unfinished Generation"
        return records()

    @ppl
    def _print_func(self, node: TraceNode, min_timestamp: float = 0.0) -> PromptRecords:
        # Render one function node: a header that toggles its body (initially
        # expanded), with every child rendered recursively via _print_node.
        name = node.name
        with Tagged("ul", attrs={"class": "list-group m-2"}):
            li_attrs = {"class": "text-center bg-light list-group-item"}
            li_attrs.update(self._toggle_attrs(name, True))
            with Tagged("li", attrs=li_attrs):
                f"<b>{self._escape(name)}</b>"
            # Bootstrap's collapse plugin acts on elements carrying "collapse";
            # "show" keeps the body expanded on load.
            with Tagged(
                "li", attrs={"class": "list-group-item collapse show", "id": name}
            ):
                # TODO: optionally display the runtime (node.start_time /
                # node.end_time relative to min_timestamp) and the call's
                # args/kwargs from node.args here.
                for child in node.children:
                    self._print_node(child, min_timestamp)
        return records()

    def _print_node(self, node: TraceNode, min_timestamp: float = 0.0) -> Any:
        """Render ``node`` as a function block if its type is "func", else as a generation."""
        if node.type == "func":
            return self._print_func(node, min_timestamp)
        else:
            return self._print_gen(node, min_timestamp)

    def _toggle_attrs(self, name: str, expanded: bool = False) -> Dict:
        """Return HTML attributes making an element toggle the collapse target with id ``name``.

        Args:
            name: Id of the collapsible element (referenced as ``#name``).
            expanded: Initial ``aria-expanded`` state to advertise.
        """
        return {
            "data-bs-toggle": "collapse",
            "href": f"#{name}",
            "role": "button",
            "aria-controls": name,
            "aria-expanded": "true" if expanded else "false",
        }

    def _make_line(self, k: str, v: Any) -> str:
        """Return a ``<tr>`` row with header cell ``k`` and data cell ``v``.

        Both values are inserted verbatim; callers escape untrusted text.
        """
        return f"<tr><th>{k}</th><td>{v}</td></tr>"

    @staticmethod
    def _escape(value: Any) -> str:
        """Return ``str(value)`` with HTML special characters escaped."""
        import html

        return html.escape(str(value))

187 

188 

class TraceLunaryPrinter(TracePrinterBase):
    """The printer used to log the trace to lunary."""

    def print(
        self, trace: TraceEngineBase, meta_data: Optional[Configs] = None
    ) -> None:
        """Log the trace to lunary.

        Each node of ``trace`` is sent as a ``start``/``end`` event pair:
        ``func`` nodes as ``"chain"`` runs, ``gen`` and ``raw_llm`` nodes as
        ``"llm"`` runs; nodes of any other type are skipped. A run id is the
        node name plus a random suffix generated once per call, and each run
        is linked to the run id of its parent node.

        The project key and API URL are read from ``LUNARY_PUBLIC_KEY`` and
        ``LUNARY_API_URL``, defaulting to a built-in key and
        ``http://localhost:3333``.

        Args:
            trace: The trace whose nodes are exported.
            meta_data: Not used by this printer.
        """
        # Imported here so this module can be imported without lunary.
        import lunary

        project_id = os.environ.get(
            "LUNARY_PUBLIC_KEY", "1c1975c5-13b9-4977-8003-89fff5c71c27"
        )
        url = os.environ.get("LUNARY_API_URL", "http://localhost:3333")
        logger.info(f"project_id: {project_id}, api url: {url}")
        lunary.config(app_id=project_id, api_url=url)

        suffix = f"_{uuid.uuid4().hex}"
        logger.info(f"suffix: {suffix}")

        def get_parent_run_id(node: TraceNode) -> Optional[str]:
            if node.parent is None:
                return None
            return node.parent.name + suffix

        for node in trace.trace_nodes.values():
            run_id = node.name + suffix
            parent_run_id = get_parent_run_id(node)
            if node.type == "func":
                logger.info(
                    f"sending func event {node.name} to lunary with parent {parent_run_id}"
                )
                lunary.track_event(
                    "chain",
                    "start",
                    run_id=run_id,
                    name=node.name,
                    parent_run_id=parent_run_id,
                    input=node.args,
                    timestamp=timestamp_to_iso(node.start_time),
                )
                lunary.track_event(
                    "chain",
                    "end",
                    run_id=run_id,
                    output=node.ret,
                    timestamp=timestamp_to_iso(node.end_time),
                )

            elif node.type == "gen":
                logger.info(
                    f"sending llm event {node.name} to lunary with parent {parent_run_id}"
                )
                self._track_llm_start(
                    lunary, node, run_id, parent_run_id, {"gen_ID": node.name}
                )
                lunary.track_event(
                    "llm",
                    "end",
                    run_id=run_id,
                    output={"role": "assistant", "content": node.ret},
                    timestamp=timestamp_to_iso(node.end_time),
                )
            elif node.type == "raw_llm":
                logger.info(
                    f"sending raw llm event {node.name} to lunary with parent {parent_run_id}"
                )
                self._track_llm_start(lunary, node, run_id, parent_run_id)
                response: Optional[ModelResponse] = node.ret  # complete response
                lunary.track_event(
                    "llm",
                    "end",
                    run_id=run_id,
                    output={
                        "role": "assistant",
                        "content": self._first_choice_content(response, node.name),
                        # TODO: support tool calls
                    },
                    timestamp=timestamp_to_iso(node.end_time),
                )

    @staticmethod
    def _track_llm_start(
        lunary: Any,
        node: TraceNode,
        run_id: str,
        parent_run_id: Optional[str],
        extra_metadata: Optional[Dict] = None,
    ) -> None:
        """Send the ``"llm"`` ``start`` event for a generation node.

        ``node.args`` is deep-copied; its ``model`` entry (default: the node
        name) becomes the run name and its ``messages`` entry (default: ``""``)
        the input. All remaining args, plus ``extra_metadata``, are sent as
        metadata.
        """
        metadata = copy.deepcopy(node.args or {})
        model_name = metadata.pop("model", node.name)
        messages = metadata.pop("messages", "")
        if extra_metadata:
            metadata.update(extra_metadata)
        lunary.track_event(
            "llm",
            "start",
            run_id=run_id,
            name=model_name,
            parent_run_id=parent_run_id,
            metadata=metadata,
            input=messages,
            timestamp=timestamp_to_iso(node.start_time),
        )

    @staticmethod
    def _first_choice_content(response: Any, node_name: str) -> Optional[str]:
        """Return the message content of the first choice in ``response``.

        If the response is missing or has no choices, log a warning and return
        ``None`` rather than aborting the whole export.
        """
        choices = getattr(response, "choices", None)
        if not choices:
            logger.warning(
                f"raw llm node {node_name} has no response choices; sending empty content"
            )
            return None
        return choices[0].message.content  # type: ignore

294 

295 

class TraceYAMLPrinter(TracePrinterBase):
    """Printer meant to render a trace as YAML; not implemented yet."""

    def print(
        self, trace: TraceEngineBase, meta_data: Optional[Configs] = None
    ) -> None:
        """Print the trace in the format of YAML.

        Placeholder: both arguments are ignored and ``None`` is returned.
        """
        # TODO: implement the YAML printer
        return None

305 

306 

class TraceProfilePrinter(TracePrinterBase):
    """The printer used to print the trace in the format of profile.

    ``print`` returns a dict in the Chrome tracing JSON layout,
    ``{"traceEvents": [...]}``, with ``ts`` offsets measured from the
    trace's ``min_timestamp`` and scaled by 1e6 (microseconds, assuming the
    event timestamps are in seconds).
    """

    def __init__(self, display_functions: bool = False):
        """Initialize the printer.

        Args:
            display_functions: Whether to display the function calls.
        """
        super().__init__()
        self._display_functions = display_functions

    def build_event(self, event: TraceEventBase, min_timestamp: float) -> Dict:
        """Build the Chrome-tracing record for a single trace event.

        Completion requests/responses become an async begin/end pair (``ph``
        ``"b"``/``"e"``) keyed by the event name. Function call/return events
        become begin/end slices (``"B"``/``"E"``) on the ``"main"`` thread,
        but only when ``display_functions`` is enabled.

        Args:
            event: The trace event to convert.
            min_timestamp: Timestamp of the earliest event in the trace.

        Returns:
            The record, or an empty dict for events this printer does not
            render. ``print`` drops empty results.
        """
        # NOTE(review): ts is emitted as a string, kept for compatibility with
        # existing consumers of this output.
        ts = str((event.time_stamp - min_timestamp) * 1e6)
        data: Dict[str, Any] = {"pid": 0, "tid": 0, "name": event.name, "ts": ts}
        # TODO: add args to the trace
        if isinstance(event, CompletionRequestEvent):
            data.update(cat="gen", ph="b", id=event.name)
        elif isinstance(event, CompletionResponseEvent):
            data.update(
                cat="gen",
                ph="e",
                id=event.name,
                cost=event.cost,
                output=event.ret.dict(),
            )
        elif self._display_functions and isinstance(event, FunctionCallEvent):
            data.update(cat="func", ph="B", tid="main")
        elif self._display_functions and isinstance(event, FunctionReturnEvent):
            data.update(cat="func", ph="E", tid="main")
        else:
            # Any other event (or a function event while display_functions is
            # off) would have no "ph" phase, which is not a valid trace event;
            # an empty dict tells print() to skip it.
            return {}
        return data

    def print(
        self, trace: TraceEngineBase, meta_data: Optional[Configs] = None
    ) -> Dict:
        """Print the trace in the format of Chrome tracing.

        Args:
            trace: The trace whose events are converted.
            meta_data: Not used by this printer.

        Returns:
            ``{"traceEvents": [...]}`` with every record ``build_event``
            produced, in trace order.
        """
        events = []
        for event in trace.events:
            # build_event returns an empty (falsy) dict for skipped events.
            if data := self.build_event(event, trace.min_timestamp):
                events.append(data)
        return {"traceEvents": events}