Coverage for src/appl/core/trace.py: 96%
120 statements
coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1import threading
2import time
3from abc import ABC, abstractmethod
4from functools import cached_property
5from inspect import signature
6from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, overload
8from loguru import logger
9from pydantic import BaseModel, model_validator
11from .config import Configs, configs
12from .globals import global_vars, inc_global_var
13from .utils import wraps
class TraceEventBase(BaseModel):
    """Common base model shared by all trace events."""

    name: str
    """Name identifying the event."""
    time_stamp: float = None  # type: ignore
    """Time stamp of the event; filled with ``time.time()`` when not provided."""

    @model_validator(mode="after")
    def _check_time_stamp(self) -> "TraceEventBase":
        """Default a missing time stamp to the current time."""
        if self.time_stamp is not None:
            return self
        # No time stamp was supplied, so record the moment of validation.
        self.time_stamp = time.time()  # type: ignore
        return self
class FunctionCallEvent(TraceEventBase):
    """Trace event describing a function call."""

    args: Dict
    """Arguments passed to the function call."""
    parent_func: Optional[str] = None
    """Name of the parent function, or ``None`` when not set."""
class FunctionReturnEvent(TraceEventBase):
    """Trace event describing a function return."""

    ret: Any = None
    """Value returned by the function (defaults to ``None``)."""
class GenerationInitEvent(TraceEventBase):
    """Trace event describing the initialization of a generation."""

    parent_func: Optional[str] = None
    """Name of the parent function, or ``None`` when not set."""
class GenerationResponseEvent(TraceEventBase):
    """Trace event describing the response of a generation."""

    args: Dict
    """Arguments the generation call was made with."""
    ret: Any
    """Value returned by the generation call."""
class CompletionRequestEvent(TraceEventBase):
    """Trace event describing a completion request."""

    parent_func: Optional[str] = None
    """Name of the parent function, or ``None`` when not set."""
class CompletionResponseEvent(TraceEventBase):
    """Trace event describing the response of a completion."""

    args: Dict
    """Arguments the completion call was made with."""
    ret: Any
    """Value returned by the completion call."""
    cost: Optional[float]
    """API cost of the completion call; may be ``None``."""
class TraceNode(BaseModel):
    """A node of the trace tree, holding information about trace events."""

    type: str
    """Type of this trace node."""
    name: str
    """Name of this trace node."""
    parent: Optional["TraceNode"] = None
    """Parent node, or ``None`` when not set."""
    children: List["TraceNode"] = []
    """Child nodes of this trace node."""
    args: Optional[Dict] = None
    """Arguments associated with this trace node."""
    ret: Any = None
    """Return value associated with this trace node."""
    start_time: float = 0.0
    """Start time of this trace node."""
    end_time: float = 0.0
    """End time of this trace node."""
    info: Dict = {}
    """Extra information attached to this trace node."""

    @property
    def runtime(self) -> float:
        """Elapsed time of the node: ``end_time`` minus ``start_time``."""
        elapsed = self.end_time - self.start_time
        return elapsed
def add_to_trace(event: TraceEventBase) -> None:
    """Append ``event`` to the global trace engine.

    Does nothing when ``global_vars.trace_engine`` is falsy (e.g. unset).
    """
    if not global_vars.trace_engine:
        return
    global_vars.trace_engine.append(event)
F = TypeVar("F", bound=Callable)


@overload
def traceable(func: F) -> F: ...


@overload
def traceable(
    func: Optional[str] = None,
    *,
    metadata: Optional[Dict] = None,
) -> Callable[[F], F]: ...


def traceable(
    func: Optional[Union[F, str]] = None,
    *,
    metadata: Optional[Dict] = None,
) -> Union[F, Callable[[F], F]]:
    """Make a function traceable.

    Can be applied directly (``@traceable``) or called with an optional
    custom name (``@traceable("name")`` / ``@traceable()``).

    Each call of the wrapped function gets an id of the form
    ``<name>_<run count>``. While a trace engine is active, a
    ``FunctionCallEvent`` is recorded before the call and a
    ``FunctionReturnEvent`` after a successful return. The id is also
    installed as ``global_vars.current_func`` for the duration of the call
    and is restored afterwards, even if the wrapped function raises.

    Args:
        func: Either the function to wrap (bare decorator usage), or a
            custom name to use instead of the function's ``__qualname__``.
        metadata: The meta information of the function to be traced.
            Currently accepted but not recorded.

    Returns:
        The wrapped function when ``func`` is callable, otherwise a
        decorator that wraps the function it is applied to.
    """
    # TODO: record metadata
    name: Optional[str] = None

    def decorator(func: F) -> F:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            func_id = name
            if func_id is None:
                func_id = func.__qualname__
            # Suffix the id with a zero-based run counter so repeated calls
            # get distinct ids (presumably inc_global_var returns the
            # incremented count -- confirm against .globals).
            func_run_cnt = inc_global_var(func_id) - 1
            func_id += f"_{func_run_cnt}"
            logger.info(
                f"Tracking function {func_id} with parent {global_vars.current_func.get()} in thread {threading.current_thread()}"
            )

            def _get_bind_args():
                sig = signature(func)
                kwargs_copy = kwargs.copy()
                # remove special args that do not need to be passed
                for key in ["_ctx", "_locals", "_globals"]:
                    if key not in sig.parameters:
                        kwargs_copy.pop(key, None)
                return sig.bind_partial(*args, **kwargs_copy)

            if global_vars.trace_engine:
                # NOTE: compute repr(args) might be time-consuming
                # TODO: jsonify the args
                add_to_trace(
                    FunctionCallEvent(
                        name=func_id,
                        args={
                            k: repr(v) for k, v in _get_bind_args().arguments.items()
                        },
                    )
                )

            # set the current function, used for the function calls inside to get the parent function
            token = global_vars.current_func.set(func_id)
            try:
                # call the inner function
                ret = func(*args, **kwargs)
            finally:
                # Always restore the previous current function, even when the
                # call raises; otherwise later calls would report this
                # (failed) function as their parent.
                global_vars.current_func.reset(token)

            if global_vars.trace_engine:
                add_to_trace(FunctionReturnEvent(name=func_id, ret=repr(ret)))
                # TODO: replace the return value with the actual value when the computation of future is finished (in trace)
                # TODO: jsonify the ret
            return ret

        return wrapper  # type: ignore

    if callable(func):
        return decorator(func)  # type: ignore
    else:
        name = func
        return decorator
class TraceEngineBase(ABC):
    """Abstract interface implemented by trace engines."""

    @property
    @abstractmethod
    def events(self) -> List[TraceEventBase]:
        """All events contained in the trace."""
        raise NotImplementedError

    @cached_property
    def min_timestamp(self) -> float:
        """Smallest time stamp among the trace events (cached after first access)."""
        return min(event.time_stamp for event in self.events)

    @property
    @abstractmethod
    def trace_nodes(self) -> Dict[str, TraceNode]:
        """Mapping of trace nodes contained in the trace."""
        raise NotImplementedError

    @abstractmethod
    def append(self, event: TraceEventBase) -> None:
        """Add ``event`` to the trace."""
        raise NotImplementedError

    @abstractmethod
    def find_cache(self, name: str, args: Dict) -> Any:
        """Look up a completion result in the cache.

        Args:
            name: The name of the completion.
            args: The arguments of the completion.

        Returns:
            The completion result if found, otherwise None.
        """
        raise NotImplementedError
def find_in_cache(
    name: str, args: Dict, cache: Optional[TraceEngineBase] = None
) -> Any:
    """Look up a completion result in a cache.

    Args:
        name: The name of the completion.
        args: The arguments of the completion.
        cache: The cache to search in. When omitted, ``global_vars.resume_cache``
            is used if ``"resume_cache"`` is present in ``global_vars``.

    Returns:
        Whatever ``cache.find_cache(name, args)`` returns, or None when no
        cache is available.
    """
    # Fall back to the global resume cache only when no cache was passed in.
    if cache is None and "resume_cache" in global_vars:
        cache = global_vars.resume_cache
    if cache is None:
        return None
    return cache.find_cache(name, args)
class TracePrinterBase(ABC):
    """Abstract interface implemented by trace printers."""

    @abstractmethod
    def print(self, trace: TraceEngineBase, meta_data: Optional[Configs] = None) -> Any:
        """Print the given trace.

        Args:
            trace: The trace engine whose content is printed.
            meta_data: Optional configs passed along with the trace.
        """
        raise NotImplementedError