Coverage for src/appl/tracing/engine.py: 80% (135 statements)
coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

import json
import os
import pickle
import re
from threading import Lock
from typing import Any, Dict, List, Optional, Type

from loguru import logger
from pydantic import BaseModel

from ..core.globals import global_vars
from ..core.trace import (
    CompletionRequestEvent,
    CompletionResponseEvent,
    FunctionCallEvent,
    FunctionReturnEvent,
    GenerationInitEvent,
    GenerationResponseEvent,
    TraceEngineBase,
    TraceEventBase,
    TraceNode,
)


class TraceEngine(TraceEngineBase):
    """The engine used to record the trace of a program execution."""

    def __init__(self, filename: str, mode: str = "write", strict: bool = True) -> None:
        """Initialize the TraceEngine.

        Args:
            filename: The filename storing the trace.
            mode: The mode of the trace, "write" or "read". Defaults to "write".
            strict:
                Whether to match strictly when used as a cache. Defaults to True.

                - True: match according to the generation id, prompts, and
                  parameters; the cache stops being used once a match fails.
                - False: match only on prompts and parameters.
        """
        self._mode = mode
        self._strict = strict
        self._events: List[TraceEventBase] = []  # events read from the file
        self._trace_nodes: Dict[str, TraceNode] = {}
        self._gen_cache: Dict[str, List[Any]] = {}
        self._lock = Lock()

        if mode == "write":
            if os.path.exists(filename):
                logger.warning(f"Trace file {filename} already exists, overwriting")
            self._file = open(filename, "wb+")
        elif mode == "read":
            if not os.path.exists(filename):
                raise FileNotFoundError(f"Trace file {filename} not found")
            self._file = open(filename, "rb+")
            self._read()
        else:
            raise ValueError(f"Invalid mode {mode}, only 'write' or 'read' allowed.")
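
    # In "read" mode, _read() replays each pickled event back through append(),
    # which rebuilds the trace nodes and refills _gen_cache so the trace file can
    # also serve as a response cache for later runs.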

    @property
    def events(self) -> List[TraceEventBase]:
        """The list of events in the trace."""
        return self._events

    @property
    def trace_nodes(self) -> Dict[str, TraceNode]:
        """The dictionary of trace nodes."""
        return self._trace_nodes

    @classmethod
    def convert_pydantic_class_to_schema(cls, class_: Type) -> Dict:
        """Convert a pydantic class to its JSON schema.

        Args:
            class_: The class to convert.
        """
        if issubclass(class_, BaseModel):
            return class_.model_json_schema()
        raise ValueError(f"Cannot convert class {class_} to schema")

    @classmethod
    def args_to_json(cls, args: Dict) -> Dict:
        """Serialize the values of the arguments to JSON format."""
        args_json = {}
        for k, v in args.items():
            if isinstance(v, type) and issubclass(v, BaseModel):
                v = cls.convert_pydantic_class_to_schema(v)
            # TODO: shall we serialize everything?
            # elif k != "message":
            #     try:
            #         v = json.dumps(v)
            #     except:
            #         v = str(v)
            args_json[k] = v
        return args_json
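
    # Illustrative example (the `Answer` model and argument keys are hypothetical,
    # not part of this module): for `class Answer(BaseModel): text: str`, calling
    # `args_to_json({"response_format": Answer, "temperature": 0.0})` replaces the
    # class with its JSON schema and passes the float through unchanged.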

    def append(self, event: TraceeEventBase if False else TraceEventBase) -> None:
        """Append an event to the trace."""
        # print(
        #     event.name,
        #     global_vars.current_func.get(),
        #     getattr(event, "parent_func", None),
        # )

        if hasattr(event, "args"):
            event.args = self.args_to_json(event.args)

        self._events.append(event)
        name, time_stamp = event.name, event.time_stamp
        if self._mode == "write":
            if isinstance(event, (FunctionCallEvent, GenerationInitEvent)):
                event.parent_func = self._last_func
            elif isinstance(event, CompletionRequestEvent):
                match = re.match(r"(.+)_raw_\d+", event.name)
                if match:
                    event.parent_func = match.group(1)
                else:
                    assert False, f"Invalid completion request name: {event.name}"

            with self._lock:
                logger.debug(f"add to trace {event}")
                pickle.dump(event, self._file)
                self._file.flush()

        assert name is not None
        if isinstance(event, FunctionCallEvent):
            newnode = self._add_node(name, event.parent_func, type="func")
            newnode.start_time = time_stamp
            newnode.args = event.args
        elif isinstance(event, FunctionReturnEvent):
            node = self._get_node(name)
            if node:
                node.ret = event.ret
                node.end_time = time_stamp
        elif isinstance(event, GenerationInitEvent):
            newnode = self._add_node(name, event.parent_func, type="gen")
            newnode.start_time = time_stamp
        elif isinstance(event, GenerationResponseEvent):
            node = self._get_node(name)
            if node:
                node.end_time = time_stamp
                node.args = event.args
                node.ret = event.ret
        elif isinstance(event, CompletionRequestEvent):
            newnode = self._add_node(name, event.parent_func, type="raw_llm")
            newnode.start_time = time_stamp
        elif isinstance(event, CompletionResponseEvent):
            node = self._get_node(name)
            if node:
                node.end_time = time_stamp
                node.args = event.args
                node.ret = event.ret
                node.info["cost"] = event.cost

            # cache the raw completion response
            key = self._cache_key(name, event.args)
            if key not in self._gen_cache:
                self._gen_cache[key] = []
            self._gen_cache[key].append(event.ret)
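
    # Every CompletionResponseEvent is also stored in _gen_cache under
    # _cache_key(name, args), so a later run in "read" mode can serve the same
    # request from find_cache() instead of re-issuing it.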

    def find_cache(self, name: str, args: Dict) -> Any:
        """Find a cached response for a generation request.

        Args:
            name: The name of the generation request.
            args: The arguments of the generation request.
        """
        args = self.args_to_json(args)
        with self._lock:
            entry_list = self._gen_cache.get(self._cache_key(name, args), None)
            if not entry_list or len(entry_list) == 0:
                return None
            entry = entry_list.pop(0)
            return entry
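
    # Cache entries are consumed in FIFO order: each hit pops the oldest recorded
    # response for the key, and a miss (or an exhausted entry list) returns None.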

    def _add_node(
        self, name: str, parent_name: Optional[str] = None, type: str = "gen"
    ) -> TraceNode:
        parent = self._get_node(parent_name)
        newnode = TraceNode(type=type, name=name, parent=parent)
        if name in self._trace_nodes:
            raise ValueError(f"Node {name} already exists in trace")
        self._trace_nodes[name] = newnode
        if parent:
            parent.children.append(newnode)
        return newnode

    def _get_node(self, name: Optional[str]) -> Optional[TraceNode]:
        if name is None:
            return None
        if name not in self._trace_nodes:
            raise ValueError(f"Node {name} not found in trace")
        return self._trace_nodes[name]
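
    # Node names are treated as unique within a trace: _add_node raises on a
    # duplicate name, and _get_node raises if a referenced name was never added.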

    @property
    def _last_func(self) -> Optional[str]:
        return global_vars.current_func.get()

    def _cache_key(self, name: str, args: Dict) -> str:
        # pop the arguments that do not affect the result
        args.pop("stream", None)
        if self._strict:
            return f"{name} {args}"
        else:
            return f"{args}"
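
    # With strict=True the key embeds the generation name, so cached responses only
    # match a replay that issues the same requests under the same names; with
    # strict=False any request with identical serialized arguments can hit the cache.
    # Note that the caller's `args` dict is mutated in place ("stream" is popped).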

    def _read(self) -> None:
        while True:
            try:
                event: TraceEventBase = pickle.load(self._file)
                self.append(event)
            except EOFError:
                break
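

# A minimal usage sketch (the request name and arguments below are placeholders;
# in practice the surrounding APPL runtime constructs the events and drives this
# engine):
#
#     engine = TraceEngine("trace.pkl", mode="write")   # record a fresh trace
#     # ... program execution appends TraceEventBase events via engine.append(...)
#
#     replay = TraceEngine("trace.pkl", mode="read")    # replays the pickled events
#     cached = replay.find_cache(request_name, request_args)  # recorded response, or None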