Coverage for src/appl/tracing/printer.py: 22%
185 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1import copy
2import json
3import os
4import uuid
5from datetime import datetime, timezone
6from typing import Any, Dict, List, Optional, Union
8from litellm import ModelResponse
9from loguru import logger
11from ..compositor import Tagged as OriginalTagged
12from ..core.config import Configs
13from ..core.io import load_file
14from ..core.printer import PromptRecords
15from ..core.trace import (
16 CompletionRequestEvent,
17 CompletionResponseEvent,
18 FunctionCallEvent,
19 FunctionReturnEvent,
20 GenerationInitEvent,
21 GenerationResponseEvent,
22 TraceEngineBase,
23 TraceEventBase,
24 TraceNode,
25 TracePrinterBase,
26)
27from ..func import partial, ppl, records
# Directory that contains this module; ``TraceHTMLPrinter`` loads
# ``header.html`` from here.
folder = os.path.dirname(__file__)

# ``Tagged`` pre-bound with ``indent_inside=4`` so every tag opened in this
# module shares the same inner indentation setting.
Tagged = partial(OriginalTagged, indent_inside=4)
def timestamp_to_iso(time_stamp: float) -> str:
    """Render a POSIX timestamp as an ISO 8601 string in UTC.

    Args:
        time_stamp: Seconds since the Unix epoch.

    Returns:
        The timezone-aware ISO 8601 representation (``+00:00`` offset).
    """
    utc_moment = datetime.fromtimestamp(time_stamp, tz=timezone.utc)
    return utc_moment.isoformat()
class TraceHTMLPrinter(TracePrinterBase):
    """The printer used to print the trace in the format of HTML.

    The page is assembled inside ``@ppl`` functions, where bare expression
    statements (strings, f-strings, call results) make up the emitted output
    collected by ``records()``; their order is the order of the HTML.
    Dynamic text taken from the trace is HTML-escaped before being embedded.
    """

    def __init__(self):
        """Initialize the printer."""
        super().__init__()
        # Bootstrap utility classes are lowercase ("bg-success-subtle").
        self._generation_style = "text-success-emphasis bg-success-subtle list-group-item d-flex justify-content-between align-items-center"
        self._time_style = "position-absolute top-0 start-100 translate-middle badge rounded-pill bg-info"
        self._cost_style = "position-absolute bottom-0 start-100 translate-middle badge rounded-pill bg-warning"
        self._longest_shown_output = 70

        self._head = load_file(os.path.join(folder, "header.html"))
        # Badge class per chat role; unknown roles fall back to "text-bg-info".
        self._color_map = {
            "user": "text-bg-info",
            "assistant": "text-bg-warning",
            "system": "text-bg-success",
        }

    @staticmethod
    def _escape(value: Any) -> str:
        """Return ``str(value)`` with HTML special characters escaped."""
        import html  # local import: keeps this class self-contained

        return html.escape(str(value))

    @ppl
    def print(self, trace: TraceEngineBase, meta_data: Optional[Configs] = None) -> str:
        """Print the trace in the format of HTML."""
        with Tagged("html"):
            self._head
            with Tagged("body"):
                # Only root nodes are printed here; children are reached
                # recursively through _print_func.
                for node in trace.trace_nodes.values():
                    if node.parent is None:
                        self._print_node(node, trace.min_timestamp)
                if meta_data:
                    with Tagged("table", attrs={"class": "table small"}):
                        if start_time := meta_data.getattrs("info.start_time"):
                            self._make_line("Start Time", self._escape(start_time))
                        self._make_line(
                            "Full Configs",
                            f"<pre>{self._escape(meta_data.to_yaml())}</pre>",
                        )
        return str(records())

    @ppl
    def _print_messages(self, messages: List[Union[str, Dict]]) -> PromptRecords:
        # Render chat messages as a list group. Dict messages that carry both
        # "role" and "content" get a role badge; anything else is shown as-is.
        # (Comments instead of a docstring so the statements captured by @ppl
        # stay exactly the HTML fragments below.)
        def display(message: Union[str, Dict]) -> str:
            return f'<div style="white-space: pre-wrap;">{self._escape(message)}</div>'

        with Tagged("ul", attrs={"class": "list-group small"}):
            for message in messages:
                with Tagged("li", attrs={"class": "list-group-item"}):
                    if (
                        isinstance(message, dict)
                        and "role" in message
                        and "content" in message
                    ):
                        color = self._color_map.get(message["role"], "text-bg-info")
                        f"<span class='badge {color}'>{self._escape(message['role'])}</span>"
                        display(message["content"])
                    else:
                        display(message)
        return records()

    @ppl
    def _print_genargs(self, args: Dict, output: Optional[str] = None) -> PromptRecords:
        # Render generation arguments as a two-column table, one row per
        # argument, followed by an optional "output" row.
        with Tagged("table", attrs={"class": "table small"}):
            for k, v in args.items():
                with Tagged("tr"):
                    with Tagged("th"):
                        self._escape(k)
                    with Tagged("td"):
                        if k == "messages":
                            self._print_messages(v)
                        elif k in ["response_format"]:
                            f"<pre>{self._escape(json.dumps(v, indent=2))}</pre>"
                        elif k in ["response_model"]:
                            f"<pre>{self._escape(v)}</pre>"
                        elif k == "stop":
                            # repr keeps whitespace/escape sequences visible.
                            self._escape(repr(v))
                        else:
                            self._escape(v)
            if output is not None:
                self._make_line("output", self._escape(output))
        return records()

    @ppl
    def _print_gen(self, node: TraceNode, min_timestamp: float = 0.0) -> PromptRecords:
        # Render one generation node: a clickable header showing the result
        # (plus runtime/cost badges read from the first child, if any) and a
        # collapsed panel with the generation arguments.
        # ``min_timestamp`` is accepted for symmetry with _print_func; unused.
        name = node.name
        if node.args is not None:
            completion = node.children[0] if node.children else None
            with Tagged("ul", attrs={"class": "list-group m-1"}):
                li_attrs = {
                    "class": self._generation_style,
                    **self._toggle_attrs(name),
                }
                with Tagged("li", attrs=li_attrs):
                    f"<div><b>{self._escape(name)}:</b> {self._escape(node.ret)}</div>"
                    if completion and (runtime := completion.runtime) > 0:
                        f"<span class='{self._time_style}'>{runtime:.2e} s</span>"
                    if completion and (cost := completion.info.get("cost")):
                        f"<span class='{self._cost_style}'>$ {cost:.2e}</span>"
                with Tagged(
                    "li", attrs={"class": "list-group-item collapse", "id": name}
                ):
                    self._print_genargs(node.args, node.ret)
        else:
            # No recorded args: show a placeholder instead of details.
            li_attrs = {
                "class": "list-group-item text-warning-emphasis bg-warning-subtle"
            }
            with Tagged("ul", attrs={"class": "list-group m-1"}):
                with Tagged("li", attrs=li_attrs):
                    "Unfinished Generation"
        return records()

    @ppl
    def _print_func(self, node: TraceNode, min_timestamp: float = 0.0) -> PromptRecords:
        # Render one function node: a header with the function name and an
        # initially expanded panel containing all child nodes.
        name = node.name
        with Tagged("ul", attrs={"class": "list-group m-2"}):
            li_attrs = {
                "class": "text-center bg-light list-group-item",
                **self._toggle_attrs(name, True),
            }
            with Tagged("li", attrs=li_attrs):
                f"<b>{self._escape(name)}</b>"
            with Tagged("li", attrs={"class": "list-group-item show", "id": name}):
                # TODO: optionally display the function's time, args and kwargs.
                for child in node.children:
                    self._print_node(child, min_timestamp)
        return records()

    def _print_node(self, node: TraceNode, min_timestamp: float = 0.0) -> Any:
        """Dispatch to the function or generation renderer by ``node.type``."""
        if node.type == "func":
            return self._print_func(node, min_timestamp)
        else:
            return self._print_gen(node, min_timestamp)

    def _toggle_attrs(self, name: str, expanded: bool = False) -> Dict:
        """Build the attributes that make an element toggle the panel ``#name``.

        Args:
            name: The id of the collapsible element to control.
            expanded: Initial value reported through ``aria-expanded``.
        """
        return {
            "data-bs-toggle": "collapse",
            "href": f"#{name}",
            "role": "button",
            "aria-controls": name,
            "aria-expanded": "true" if expanded else "false",
        }

    def _make_line(self, k: str, v: Any) -> str:
        """Return one table row with header ``k`` and cell ``v``.

        ``v`` is inserted verbatim (callers may pass markup such as
        ``<pre>...</pre>``), so callers must escape untrusted text themselves.
        """
        return f"<tr><th>{k}</th><td>{v}</td></tr>"
class TraceLunaryPrinter(TracePrinterBase):
    """The printer used to log the trace to lunary."""

    def print(
        self, trace: TraceEngineBase, meta_data: Optional[Configs] = None
    ) -> None:
        """Log the trace to lunary.

        Every node of type ``"func"``, ``"gen"`` or ``"raw_llm"`` is sent as a
        start/end event pair; nodes of any other type are ignored. Run ids are
        the node name plus a per-call random suffix.

        Args:
            trace: The trace whose nodes are sent.
            meta_data: Not used by this printer.
        """
        import lunary

        project_id = os.environ.get(
            "LUNARY_PUBLIC_KEY", "1c1975c5-13b9-4977-8003-89fff5c71c27"
        )
        url = os.environ.get("LUNARY_API_URL", "http://localhost:3333")
        logger.info(f"project_id: {project_id}, api url: {url}")
        lunary.config(app_id=project_id, api_url=url)

        # A fresh suffix per call keeps run ids distinct across calls.
        suffix = f"_{uuid.uuid4().hex}"
        logger.info(f"suffix: {suffix}")

        def get_parent_run_id(node: TraceNode) -> Optional[str]:
            if node.parent is None:
                return None
            return node.parent.name + suffix

        def send_llm_start(node: TraceNode, tag_gen_id: bool) -> None:
            # Model name and messages get dedicated fields; the remaining
            # arguments travel as metadata. Work on a copy so the trace node
            # itself is left untouched.
            metadata = copy.deepcopy(node.args or {})
            model_name = metadata.pop("model", node.name)
            messages = metadata.pop("messages", "")
            if tag_gen_id:
                metadata["gen_ID"] = node.name
            lunary.track_event(
                "llm",
                "start",
                run_id=node.name + suffix,
                name=model_name,
                parent_run_id=get_parent_run_id(node),
                metadata=metadata,
                input=messages,
                timestamp=timestamp_to_iso(node.start_time),
            )

        def send_end(kind: str, node: TraceNode, output: Any) -> None:
            lunary.track_event(
                kind,
                "end",
                run_id=node.name + suffix,
                output=output,
                timestamp=timestamp_to_iso(node.end_time),
            )

        for node in trace.trace_nodes.values():
            if node.type == "func":
                logger.info(
                    f"sending func event {node.name} to lunary with parent {get_parent_run_id(node)}"
                )
                lunary.track_event(
                    "chain",
                    "start",
                    run_id=node.name + suffix,
                    name=node.name,
                    parent_run_id=get_parent_run_id(node),
                    input=node.args,
                    timestamp=timestamp_to_iso(node.start_time),
                )
                send_end("chain", node, node.ret)

            elif node.type == "gen":
                logger.info(
                    f"sending llm event {node.name} to lunary with parent {get_parent_run_id(node)}"
                )
                # NOTE: legacy "_raw" generations are intentionally not skipped.
                send_llm_start(node, tag_gen_id=True)
                send_end("llm", node, {"role": "assistant", "content": node.ret})

            elif node.type == "raw_llm":
                logger.info(
                    f"sending raw llm event {node.name} to lunary with parent {get_parent_run_id(node)}"
                )
                send_llm_start(node, tag_gen_id=False)
                response: ModelResponse = node.ret  # complete response
                output = {
                    "role": "assistant",
                    "content": response.choices[0].message.content,  # type: ignore
                    # TODO: support tool calls
                }
                send_end("llm", node, output)
class TraceYAMLPrinter(TracePrinterBase):
    """The printer used to print the trace in the format of YAML."""

    def print(
        self, trace: TraceEngineBase, meta_data: Optional[Configs] = None
    ) -> None:
        """Print the trace in the format of YAML.

        Not implemented yet: the call is a no-op and returns ``None``.
        """
        # TODO: implement the YAML printer
        return None
class TraceProfilePrinter(TracePrinterBase):
    """The printer used to print the trace in the format of profile."""

    def __init__(self, display_functions: bool = False):
        """Initialize the printer.

        Args:
            display_functions: Whether to display the function calls.
        """
        super().__init__()
        self._display_functions = display_functions

    def build_event(self, event: TraceEventBase, min_timestamp: float) -> Dict:
        """Build the event for the trace.

        Args:
            event: The trace event to convert.
            min_timestamp: The time origin; ``ts`` is reported relative to it.

        Returns:
            The event record, or an empty dict when the event is not shown
            (an unhandled event type, or a function event while
            ``display_functions`` is off). An empty dict is falsy, which lets
            ``print`` drop such events instead of emitting records that lack
            the ``ph``/``cat`` fields.
        """
        # Seconds -> microseconds relative to the start of the trace.
        ts = str((event.time_stamp - min_timestamp) * 1e6)
        data: Dict[str, Any] = {"pid": 0, "tid": 0, "name": event.name, "ts": ts}
        # TODO: add args to the trace
        if isinstance(event, CompletionRequestEvent):
            # "b"/"e" pair up by "id".
            data["cat"] = "gen"
            data["ph"] = "b"
            data["id"] = event.name
        elif isinstance(event, CompletionResponseEvent):
            data["cat"] = "gen"
            data["ph"] = "e"
            data["id"] = event.name
            data["cost"] = event.cost
            data["output"] = event.ret.dict()
        elif self._display_functions and isinstance(event, FunctionCallEvent):
            data["cat"] = "func"
            data["ph"] = "B"
            data["tid"] = "main"
        elif self._display_functions and isinstance(event, FunctionReturnEvent):
            data["cat"] = "func"
            data["ph"] = "E"
            data["tid"] = "main"
        else:
            return {}
        return data

    def print(
        self, trace: TraceEngineBase, meta_data: Optional[Configs] = None
    ) -> Dict:
        """Print the trace in the format of Chrome tracing.

        Returns:
            A dict with a single ``"traceEvents"`` list holding the records of
            all displayed events, in trace order.
        """
        events = []
        for event in trace.events:
            if data := self.build_event(event, trace.min_timestamp):
                events.append(data)
        return {"traceEvents": events}