Coverage for src/appl/servers/api.py: 69% (137 statements)

coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

import asyncio
import time
from functools import wraps
from importlib.metadata import version
from typing import Any, Callable, Dict, List, Optional, Tuple

import litellm
import yaml
from litellm import (
    CustomStreamWrapper,
    ModelResponse,
    completion_cost,
    stream_chunk_builder,
)
from litellm.exceptions import NotFoundError
from loguru import logger
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from pydantic import BaseModel

from ..core.config import configs
from ..core.message import Conversation
from ..core.response import CompletionResponse
from ..core.server import BaseServer, GenArgs
from ..core.tool import ToolCall
from ..core.trace import (
    CompletionRequestEvent,
    CompletionResponseEvent,
    add_to_trace,
    find_in_cache,
)
from ..utils import _langsmith_traceable
from ..version import __version__

if configs.getattrs("settings.misc.suppress_litellm_debug_info"):
    litellm.suppress_debug_info = True


# wrap the completion function # TODO: wrap the acompletion function?
@wraps(litellm.completion)
def chat_completion(**kwargs: Any) -> CompletionResponse:
    """Wrap the litellm.completion function to add tracing and logging."""
    if "gen_id" not in kwargs:
        raise ValueError("gen_id is required for tracing completion generation.")
    gen_id = kwargs.pop("gen_id")
    raw_response_holder = []
    if "_raw_response_holder" in kwargs:
        raw_response_holder = kwargs.pop("_raw_response_holder")
    add_to_trace(CompletionRequestEvent(name=gen_id))

    log_llm_call_args = configs.getattrs("settings.logging.display.llm_raw_call_args")
    log_llm_response = configs.getattrs("settings.logging.display.llm_raw_response")
    log_llm_usage = configs.getattrs("settings.logging.display.llm_raw_usage")
    log_llm_cache = configs.getattrs("settings.logging.display.llm_cache")
    if log_llm_call_args:
        logger.info(f"Call completion [{gen_id}] with args: {kwargs}")

    @_langsmith_traceable(
        name=f"ChatCompletion_{gen_id}",
        run_type="llm",
        metadata={"appl": "completion", "appl_version": __version__},
    )  # type: ignore
    def wrapped(**inner_kwargs: Any) -> Tuple[Any, bool]:
        if cache_ret := find_in_cache(gen_id, inner_kwargs):
            if log_llm_cache:
                logger.info("Found in cache, using cached response...")
            # ? support rebuild the stream from cached response
            if inner_kwargs.get("stream", False):
                logger.warning(
                    "Using cached complete response for a streaming generation."
                )
            raw_response = cache_ret
        else:
            # if log_llm_cache:
            #     logger.info("Not found in cache, creating response...")
            raw_response = litellm.completion(**inner_kwargs)
        return raw_response, cache_ret is not None

    try:
        raw_response, use_cache = wrapped(**kwargs)
    except Exception as e:
        # log the error information for debugging
        logger.error(f"Error encountered for the completion: {e}")
        logger.info(f"kwargs:\n{kwargs}")
        raise e

    if raw_response_holder is not None:
        raw_response_holder.append(raw_response)

    def post_completion(response: CompletionResponse) -> None:
        raw_response = response.complete_response
        cost = 0.0 if use_cache else response.cost
        response.cost = cost  # update the cost
        event = CompletionResponseEvent(
            name=gen_id, args=kwargs, ret=raw_response, cost=cost
        )
        add_to_trace(event)
        if log_llm_response:
            logger.info(f"Completion [{gen_id}] response: {response}")
        if log_llm_usage and response.usage is not None:
            logger.info(f"Completion [{gen_id}] usage: {response.usage}")

    return CompletionResponse(
        raw_response=raw_response, post_finish_callbacks=[post_completion]
    )  # type: ignore

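# Usage sketch (illustrative, not part of the original module): chat_completion
# is called with the usual litellm.completion keyword arguments plus a `gen_id`
# used for tracing, e.g.
#
#     response = chat_completion(
#         gen_id="gen_0",  # hypothetical id, normally supplied by the caller
#         model="gpt-4o-mini",
#         messages=[{"role": "user", "content": "Hello"}],
#     )
#     print(response.results)
#
# The returned CompletionResponse triggers post_completion once it finishes,
# recording the call in the trace and zeroing the cost for cached responses.
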

# TODO: add default batch_size, to avoid too many requests
class APIServer(BaseServer):
    """The server for API models. It is a wrapper of litellm.completion."""

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        cost_currency: str = "USD",
        **kwargs: Any,
    ) -> None:
        """Initialize the API server.

        See [LiteLLM](https://docs.litellm.ai/docs/providers)
        for available models and providers.
        See [completion](https://docs.litellm.ai/docs/completion/input#input-params-1)
        for available options.
        """
        super().__init__()
        self._model = model
        self._base_url = base_url
        self._api_key = api_key
        self._custom_llm_provider = custom_llm_provider
        if custom_llm_provider is not None and api_key is None:
            self._api_key = "NotRequired"  # bypass the api_key check of litellm
        self._cost_currency = cost_currency
        self._default_args = kwargs

    @property
    def model_name(self):
        """The model name."""
        return self._model

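    # Instantiation sketch (values are assumptions, not from this module):
    #
    #     server = APIServer(
    #         model="gpt-4o-mini",
    #         api_key="sk-...",     # or rely on provider environment variables
    #         temperature=0.0,      # extra kwargs become default completion args
    #     )
    #
    # When custom_llm_provider is set without an api_key, the key falls back to
    # "NotRequired" so litellm's key check is bypassed (e.g. for local endpoints).
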

    def _get_create_args(self, args: GenArgs, **kwargs: Any) -> dict:
        # supports custom postprocessing of create_args
        create_args = self._default_args.copy()
        postprocess = kwargs.pop("postprocess_args", None)
        if self._base_url is not None:
            create_args["base_url"] = self._base_url
        if self._api_key is not None:
            create_args["api_key"] = self._api_key
        if self._custom_llm_provider:
            create_args["custom_llm_provider"] = self._custom_llm_provider
        create_args.update(kwargs)  # update create_args with other kwargs

        # add args to create_args
        create_args.update(args.preprocess(self._convert, is_openai=True))
        if postprocess is not None:
            create_args = postprocess(create_args)

        return create_args

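    # Sketch of the result (example values are assumptions): a server configured
    # with a base_url and a custom provider yields roughly
    #
    #     {
    #         "base_url": "http://localhost:8000/v1",
    #         "api_key": "NotRequired",
    #         "custom_llm_provider": "openai",
    #         "messages": [...],  # from args.preprocess via self._convert
    #         ...                 # plus default args and any per-call kwargs
    #     }
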

    def _create(self, **kwargs: Any) -> CompletionResponse:
        response: CompletionResponse = None  # type: ignore
        # to store the raw response when the patched completion encounters an error
        kwargs["_raw_response_holder"] = []

        response_model = kwargs.get("response_model", None)
        response_format = kwargs.get("response_format", None)
        if response_model is not None:
            try:
                from instructor.mode import Mode
                from instructor.patch import patch
            except ImportError:
                raise RuntimeError(
                    "response_model requires instructor, install with `pip install instructor`"
                )

            def wrapper(**inner_kwargs: Any) -> CompletionResponse:
                nonlocal response
                response = chat_completion(**inner_kwargs)
                return response

            try:
                # Use instructor.patch to enable using a pydantic model as response model
                # added arguments: response_model, validation_context, max_retries
                mode = kwargs.pop("instructor_patch_mode", Mode.JSON)
                patched = patch(create=wrapper, mode=mode)
                results = patched(**kwargs)
                # fill in the response_model and response_obj
                response.response_model = response_model
                response.set_response_obj(results)
                # TODO?: update the cost for multiple retries
                # instructor has updated the total usage for retries
                # ?? response.cost = completion_cost({"usage": response.usage})
            except Exception as e:
                # log the error information for debugging
                logger.error(f"Error encountered for the patched completion: {e}")
                _raw_response_holder = kwargs.pop("_raw_response_holder", [])
                logger.info(f"kwargs:\n{yaml.dump(kwargs)}")
                if _raw_response_holder:
                    logger.info(f"raw_response:\n{_raw_response_holder[0]}")
                raise e
        else:
            wrapped_attribute = kwargs.pop("_wrapped_attribute", None)
            response = chat_completion(**kwargs)
            if isinstance(response_format, type) and issubclass(
                response_format, BaseModel
            ):
                response.response_model = response_format
                assert (
                    response.response_obj is None
                ), "response_obj should not be set yet."
                # retrieve the response message and convert it to the response model
                # fetching the results will stream the response if it is a streaming one
                response_obj = response_format.model_validate_json(response.results)
                if wrapped_attribute:
                    assert hasattr(
                        response_obj, wrapped_attribute
                    ), f"should have attribute {wrapped_attribute} in the response"
                    response_obj = getattr(response_obj, wrapped_attribute)
                response.set_response_obj(response_obj)
        return response

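    # Usage sketch for structured output (argument values are assumptions):
    # passing a pydantic model as response_format makes _create validate the
    # returned JSON into that model, e.g.
    #
    #     class Answer(BaseModel):
    #         text: str
    #
    #     response = server._create(
    #         gen_id="gen_1",
    #         model="gpt-4o-mini",
    #         messages=[{"role": "user", "content": "Reply as JSON."}],
    #         response_format=Answer,
    #     )
    #     answer: Answer = response.response_obj
    #
    # With response_model instead, the optional instructor package patches the
    # completion call and handles validation and retries itself.
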

    def _convert(self, conversation: Conversation) -> List[Dict[str, str]]:
        return conversation.as_list()

    def close(self):
        """Close the server."""
        pass