Coverage for src/appl/tracing/engine.py: 80%

135 statements  

coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

import json
import os
import pickle
import re
from threading import Lock
from typing import Any, Dict, List, Optional, Type

from loguru import logger
from pydantic import BaseModel

from ..core.globals import global_vars
from ..core.trace import (
    CompletionRequestEvent,
    CompletionResponseEvent,
    FunctionCallEvent,
    FunctionReturnEvent,
    GenerationInitEvent,
    GenerationResponseEvent,
    TraceEngineBase,
    TraceEventBase,
    TraceNode,
)


class TraceEngine(TraceEngineBase):
    """The engine used to record the trace of a program execution."""

    def __init__(self, filename: str, mode: str = "write", strict: bool = True) -> None:
        """Initialize the TraceEngine.

        Args:
            filename: The filename storing the trace.
            mode: The mode of the trace, "write" or "read". Defaults to "write".
            strict:
                Whether to match strictly when used as a cache. Defaults to True.

                - True: match according to the generation id, prompts, and
                  parameters; the cache stops working once a match fails.
                - False: match only prompts and parameters.
        """
        self._mode = mode
        self._strict = strict
        self._events: List[TraceEventBase] = []  # events recorded or read from the file
        self._trace_nodes: Dict[str, TraceNode] = {}
        self._gen_cache: Dict[str, List[Any]] = {}
        self._lock = Lock()

        if mode == "write":
            if os.path.exists(filename):
                logger.warning(f"Trace file {filename} already exists, overwriting")
            self._file = open(filename, "wb+")
        elif mode == "read":
            if not os.path.exists(filename):
                raise FileNotFoundError(f"Trace file {filename} not found")
            self._file = open(filename, "rb+")
            self._read()
        else:
            raise ValueError(f"Invalid mode {mode}, only 'write' or 'read' allowed.")

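    # Usage sketch (illustrative; the file name is hypothetical):
    #
    #     engine = TraceEngine("trace.pkl", mode="write")                # record a new trace
    #     cache = TraceEngine("trace.pkl", mode="read", strict=False)    # reuse it as a cache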

    @property
    def events(self) -> List[TraceEventBase]:
        """The list of events in the trace."""
        return self._events

    @property
    def trace_nodes(self) -> Dict[str, TraceNode]:
        """The dictionary of trace nodes."""
        return self._trace_nodes

    @classmethod
    def convert_pydantic_class_to_schema(cls, class_: Type) -> Dict:
        """Convert a pydantic model class to its JSON schema.

        Args:
            class_: The class to convert.
        """
        if issubclass(class_, BaseModel):
            return class_.model_json_schema()
        raise ValueError(f"Cannot convert class {class_} to schema")

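    # Example (illustrative; `Answer` is a hypothetical pydantic model, not part
    # of this module):
    #
    #     class Answer(BaseModel):
    #         text: str
    #
    #     TraceEngine.convert_pydantic_class_to_schema(Answer)
    #     # -> roughly {"properties": {"text": {"title": "Text", "type": "string"}},
    #     #             "required": ["text"], "title": "Answer", "type": "object"}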

    @classmethod
    def args_to_json(cls, args: Dict) -> Dict:
        """Serialize the values of the arguments to JSON format."""
        args_json = {}
        for k, v in args.items():
            if isinstance(v, type) and issubclass(v, BaseModel):
                v = cls.convert_pydantic_class_to_schema(v)
            # TODO: shall we serialize everything?
            # elif k != "message":
            #     try:
            #         v = json.dumps(v)
            #     except:
            #         v = str(v)
            args_json[k] = v
        return args_json

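    # Example (illustrative; `Answer` is the hypothetical pydantic model above):
    #
    #     TraceEngine.args_to_json({"temperature": 0.0, "response_format": Answer})
    #     # -> {"temperature": 0.0, "response_format": <Answer's JSON schema>}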

    def append(self, event: TraceEventBase) -> None:
        """Append an event to the trace."""
        # print(
        #     event.name,
        #     global_vars.current_func.get(),
        #     getattr(event, "parent_func", None),
        # )

        if hasattr(event, "args"):
            event.args = self.args_to_json(event.args)

        self._events.append(event)
        name, time_stamp = event.name, event.time_stamp
        if self._mode == "write":
            # In write mode, attach the parent function and persist the event
            # to the trace file immediately.
            if isinstance(event, (FunctionCallEvent, GenerationInitEvent)):
                event.parent_func = self._last_func
            elif isinstance(event, CompletionRequestEvent):
                # A completion request named "<gen>_raw_<i>" belongs to generation "<gen>".
                match = re.match(r"(.+)_raw_\d+", event.name)
                if match:
                    event.parent_func = match.group(1)
                else:
                    assert False, f"Invalid completion request name: {event.name}"

            with self._lock:
                logger.debug(f"add to trace {event}")
                pickle.dump(event, self._file)
                self._file.flush()

        assert name is not None
        # Build the tree of trace nodes: call/init/request events open a node,
        # and the matching return/response events close it.
        if isinstance(event, FunctionCallEvent):
            newnode = self._add_node(name, event.parent_func, type="func")
            newnode.start_time = time_stamp
            newnode.args = event.args
        elif isinstance(event, FunctionReturnEvent):
            node = self._get_node(name)
            if node:
                node.ret = event.ret
                node.end_time = time_stamp
        elif isinstance(event, GenerationInitEvent):
            newnode = self._add_node(name, event.parent_func, type="gen")
            newnode.start_time = time_stamp
        elif isinstance(event, GenerationResponseEvent):
            node = self._get_node(name)
            if node:
                node.end_time = time_stamp
                node.args = event.args
                node.ret = event.ret
        elif isinstance(event, CompletionRequestEvent):
            newnode = self._add_node(name, event.parent_func, type="raw_llm")
            newnode.start_time = time_stamp
        elif isinstance(event, CompletionResponseEvent):
            node = self._get_node(name)
            if node:
                node.end_time = time_stamp
                node.args = event.args
                node.ret = event.ret
                node.info["cost"] = event.cost

            # cache the raw completion response so it can be replayed in "read" mode
            key = self._cache_key(name, event.args)
            if key not in self._gen_cache:
                self._gen_cache[key] = []
            self._gen_cache[key].append(event.ret)

    def find_cache(self, name: str, args: Dict) -> Any:
        """Find a cached response for a generation request.

        Args:
            name: The name of the generation request.
            args: The arguments of the generation request.
        """
        args = self.args_to_json(args)
        with self._lock:
            entry_list = self._gen_cache.get(self._cache_key(name, args), None)
            if not entry_list or len(entry_list) == 0:
                return None
            # consume cached responses in order, one per lookup
            entry = entry_list.pop(0)
            return entry

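    # Example (illustrative; the generation name and arguments are hypothetical):
    #
    #     engine = TraceEngine("trace.pkl", mode="read")
    #     cached = engine.find_cache("gen_1_raw_0", {"model": "gpt-4o", "messages": []})
    #     # -> a previously recorded response, or None if no (further) match exists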

    def _add_node(
        self, name: str, parent_name: Optional[str] = None, type: str = "gen"
    ) -> TraceNode:
        parent = self._get_node(parent_name)
        newnode = TraceNode(type=type, name=name, parent=parent)
        if name in self._trace_nodes:
            raise ValueError(f"Node {name} already exists in trace")
        self._trace_nodes[name] = newnode
        if parent:
            parent.children.append(newnode)
        return newnode

    def _get_node(self, name: Optional[str]) -> Optional[TraceNode]:
        if name is None:
            return None
        if name not in self._trace_nodes:
            raise ValueError(f"Node {name} not found in trace")
        return self._trace_nodes[name]

    @property
    def _last_func(self) -> Optional[str]:
        return global_vars.current_func.get()

    def _cache_key(self, name: str, args: Dict) -> str:
        # pop the arguments that do not affect the result
        args.pop("stream", None)
        if self._strict:
            # strict: the key includes the generation name, so a cached response
            # is only reused by the generation that produced it
            return f"{name} {args}"
        else:
            # non-strict: the key is the arguments alone
            return f"{args}"

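    # For example, with strict matching the key looks like
    # "gen_1_raw_0 {'model': ...}", while non-strict matching drops the name and
    # uses "{'model': ...}" alone, so any request with identical arguments can
    # reuse the cached response.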

    def _read(self) -> None:
        # Replay every pickled event from the trace file (until EOF), rebuilding
        # the trace nodes and the generation cache in memory.
        while True:
            try:
                event: TraceEventBase = pickle.load(self._file)
                self.append(event)
            except EOFError:
                break