Coverage for src/appl/core/trace.py: 96%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1import threading 

2import time 

3from abc import ABC, abstractmethod 

4from functools import cached_property 

5from inspect import signature 

6from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, overload 

7 

8from loguru import logger 

9from pydantic import BaseModel, model_validator 

10 

11from .config import Configs, configs 

12from .globals import global_vars, inc_global_var 

13from .utils import wraps 

14 

15 

class TraceEventBase(BaseModel):
    """Common base for every event recorded into a trace."""

    # Identifier of the event (for function events created by ``traceable``
    # this is the per-call function id).
    name: str
    # Wall-clock time of the event as returned by ``time.time()``; filled in
    # automatically by the validator below when not supplied.
    time_stamp: float = None  # type: ignore

    @model_validator(mode="after")
    def _check_time_stamp(self) -> "TraceEventBase":
        # Keep an explicitly supplied time stamp untouched.
        if self.time_stamp is not None:
            return self
        # Otherwise stamp the event with the moment it was validated.
        self.time_stamp = time.time()  # type: ignore
        return self

30 

31 

class FunctionCallEvent(TraceEventBase):
    """Event recorded when a traced function is entered."""

    # Arguments of the call. ``traceable`` fills this with
    # parameter name -> repr(value).
    args: Dict
    # Name of the enclosing (parent) function, if one is provided.
    parent_func: Union[str, None] = None

39 

40 

class FunctionReturnEvent(TraceEventBase):
    """Event recorded after a traced function returns."""

    # Value returned by the function (``traceable`` stores its repr).
    ret: Any = None

46 

47 

class GenerationInitEvent(TraceEventBase):
    """Event recorded when a generation is initialized."""

    # Name of the enclosing (parent) function, if one is provided.
    parent_func: Union[str, None] = None

53 

54 

class GenerationResponseEvent(TraceEventBase):
    """Event recorded when a generation call produces its response."""

    # Arguments the generation call was made with.
    args: Dict
    # Value returned by the generation call (required, may be any type).
    ret: Any

62 

63 

class CompletionRequestEvent(TraceEventBase):
    """Event recorded when a completion request is issued."""

    # Name of the enclosing (parent) function, if one is provided.
    parent_func: Union[str, None] = None

69 

70 

class CompletionResponseEvent(TraceEventBase):
    """Event recorded when a completion call returns its response."""

    # Arguments the completion call was made with.
    args: Dict
    # Value returned by the completion call (required, may be any type).
    ret: Any
    # API cost of the call. The field must be passed explicitly (it has no
    # default) but may be None when the cost is unknown.
    cost: Union[float, None]

80 

81 

class TraceNode(BaseModel):
    """One node of the trace tree, aggregating what was recorded for it."""

    # Kind of node; the set of values is decided by the trace engine.
    type: str
    # Name / identifier of the node.
    name: str
    # Enclosing node, or None for a root node.
    parent: Optional["TraceNode"] = None
    # Nodes nested directly under this one.
    children: List["TraceNode"] = []
    # Arguments associated with the node, if any.
    args: Optional[Dict] = None
    # Return value associated with the node, if any.
    ret: Any = None
    # Start time of the node (0.0 until set).
    start_time: float = 0.0
    # End time of the node (0.0 until set).
    end_time: float = 0.0
    # Free-form extra information about the node.
    info: Dict = {}

    @property
    def runtime(self) -> float:
        """Elapsed time of the node, i.e. ``end_time - start_time``."""
        begin, finish = self.start_time, self.end_time
        return finish - begin

108 

109 

def add_to_trace(event: TraceEventBase) -> None:
    """Hand ``event`` to the active trace engine; do nothing if none is set.

    The engine is considered active when ``global_vars.trace_engine`` is truthy.
    """
    engine = global_vars.trace_engine
    if not engine:
        return
    engine.append(event)

114 

115 

# Type variable standing for "any callable"; used in the ``traceable``
# signatures so type checkers see the decorated function keep its own type.
F = TypeVar("F", bound=Callable)

117 

118 

# Typing overload for bare usage (``@traceable`` applied directly to a
# function): the result has the same type as the decorated function.
@overload
def traceable(func: F) -> F: ...

121 

122 

# Typing overload for parameterized usage (``@traceable("custom_name")`` or
# ``@traceable(metadata=...)``): returns a decorator that preserves the
# decorated function's type.
@overload
def traceable(
    func: Optional[str] = None,
    *,
    metadata: Optional[Dict] = None,
) -> Callable[[F], F]: ...

129 

130 

def traceable(
    func: Optional[Union[F, str]] = None,
    *,
    metadata: Optional[Dict] = None,
) -> Union[F, Callable[[F], F]]:
    """Make a function traceable.

    Usable bare (``@traceable``) or with arguments
    (``@traceable("custom_name")`` / ``@traceable(metadata=...)``).

    Each call of the wrapped function gets an id of the form
    ``"<name>_<n>"``. ``<name>`` is the custom name, or the function's
    ``__qualname__`` when no name was given. ``<n>`` is derived from
    ``inc_global_var(<name>) - 1``; presumably this is the number of earlier
    calls under that name.

    While the function runs, the id is stored in ``global_vars.current_func``
    so that nested traced calls can report it as their parent. The previous
    value is always restored afterwards, including when the function raises.

    When a trace engine is active:

    - a ``FunctionCallEvent`` (with repr'd arguments) is recorded before the
      call;
    - a ``FunctionReturnEvent`` (with the repr'd return value) is recorded
      after a successful return.

    Args:
        func: Either the function to decorate (bare usage) or a custom name
            for the traced function.
        metadata: The meta information of the function to be traced.
            Currently accepted but not recorded.

    Returns:
        The wrapped function for bare usage, otherwise a decorator.
    """
    # TODO: record metadata
    # Custom name; read lazily by ``wrapper`` at call time, so assigning it
    # below (after ``decorator`` is defined) is still seen by the closure.
    name: Optional[str] = None

    def decorator(func: F) -> F:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            func_id = name
            if func_id is None:
                func_id = func.__qualname__
            func_run_cnt = inc_global_var(func_id) - 1
            func_id += f"_{func_run_cnt}"
            logger.info(
                f"Tracking function {func_id} with parent {global_vars.current_func.get()} in thread {threading.current_thread()}"
            )

            def _get_bind_args():
                sig = signature(func)
                kwargs_copy = kwargs.copy()
                # Drop the special kwargs (_ctx/_locals/_globals) unless the
                # function actually declares them, so that ``bind_partial``
                # does not reject them as unexpected keyword arguments.
                for key in ["_ctx", "_locals", "_globals"]:
                    if key not in sig.parameters:
                        kwargs_copy.pop(key, None)
                return sig.bind_partial(*args, **kwargs_copy)

            if global_vars.trace_engine:
                # NOTE: compute repr(args) might be time-consuming
                # TODO: jsonify the args
                add_to_trace(
                    FunctionCallEvent(
                        name=func_id,
                        args={
                            k: repr(v) for k, v in _get_bind_args().arguments.items()
                        },
                    )
                )

            # Set the current function so traced calls made inside can find
            # their parent.
            token = global_vars.current_func.set(func_id)
            try:
                # call the inner function
                ret = func(*args, **kwargs)
            finally:
                # Restore the previous current function even if the call
                # raised. Without this, the context would keep pointing at this
                # finished call and mis-parent every later traced call.
                global_vars.current_func.reset(token)

            if global_vars.trace_engine:
                add_to_trace(FunctionReturnEvent(name=func_id, ret=repr(ret)))
                # TODO: replace the return value with the actual value when the computation of future is finished (in trace)
                # TODO: jsonify the ret
            return ret

        return wrapper  # type: ignore

    if callable(func):
        # Bare usage: ``func`` is the function being decorated.
        return decorator(func)  # type: ignore
    # Parameterized usage: ``func`` is the optional custom name.
    name = func
    return decorator

199 

200 

class TraceEngineBase(ABC):
    """Abstract interface for objects that collect and index trace events."""

    @property
    @abstractmethod
    def events(self) -> List[TraceEventBase]:
        """All events recorded in the trace so far."""
        raise NotImplementedError

    @cached_property
    def min_timestamp(self) -> float:
        """Earliest ``time_stamp`` among ``events``.

        The value is computed on first access and cached on the instance, so
        events appended afterwards are not taken into account. If there are
        no events, ``min`` raises ``ValueError`` and nothing is cached.
        """
        return min(evt.time_stamp for evt in self.events)

    @property
    @abstractmethod
    def trace_nodes(self) -> Dict[str, TraceNode]:
        """Mapping from node name to the corresponding trace node."""
        raise NotImplementedError

    @abstractmethod
    def append(self, event: TraceEventBase) -> None:
        """Record ``event`` in the trace."""
        raise NotImplementedError

    @abstractmethod
    def find_cache(self, name: str, args: Dict) -> Any:
        """Look up a completion result recorded in this trace.

        Args:
            name: The name of the completion.
            args: The arguments of the completion.

        Returns:
            The completion result if found, otherwise None.
        """
        raise NotImplementedError

238 

239 

def find_in_cache(
    name: str, args: Dict, cache: Optional[TraceEngineBase] = None
) -> Any:
    """Look up a previously recorded completion result.

    Args:
        name: The name of the completion.
        args: The arguments of the completion.
        cache: The engine to query. When omitted, ``global_vars.resume_cache``
            is used if that global has been set.

    Returns:
        Whatever ``cache.find_cache`` returns, or None when no cache is
        available.
    """
    # Fall back to the global resume cache only when no cache was passed in.
    if cache is None and "resume_cache" in global_vars:
        cache = global_vars.resume_cache
    if cache is None:
        return None
    return cache.find_cache(name, args)

259 

260 

class TracePrinterBase(ABC):
    """Abstract interface for objects that render a collected trace."""

    @abstractmethod
    def print(self, trace: TraceEngineBase, meta_data: Optional[Configs] = None) -> Any:
        """Render ``trace``.

        Args:
            trace: The trace engine holding the events/nodes to render.
            meta_data: Optional configuration passed along with the trace;
                how it is used is up to the implementation.

        Returns:
            Implementation-defined.
        """
        raise NotImplementedError()