Coverage for src/appl/servers/api.py: 69%
137 statements
import asyncio
import time
from functools import wraps
from importlib.metadata import version
from typing import Any, Callable, Dict, List, Optional, Tuple

import litellm
import yaml
from litellm import (
    CustomStreamWrapper,
    ModelResponse,
    completion_cost,
    stream_chunk_builder,
)
from litellm.exceptions import NotFoundError
from loguru import logger
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from pydantic import BaseModel

from ..core.config import configs
from ..core.message import Conversation
from ..core.response import CompletionResponse
from ..core.server import BaseServer, GenArgs
from ..core.tool import ToolCall
from ..core.trace import (
    CompletionRequestEvent,
    CompletionResponseEvent,
    add_to_trace,
    find_in_cache,
)
from ..utils import _langsmith_traceable
from ..version import __version__

if configs.getattrs("settings.misc.suppress_litellm_debug_info"):
    litellm.suppress_debug_info = True


# wrap the completion function  # TODO: wrap the acompletion function?
@wraps(litellm.completion)
def chat_completion(**kwargs: Any) -> CompletionResponse:
    """Wrap the litellm.completion function to add tracing and logging."""
    if "gen_id" not in kwargs:
        raise ValueError("gen_id is required for tracing completion generation.")
    gen_id = kwargs.pop("gen_id")
    raw_response_holder = []
    if "_raw_response_holder" in kwargs:
        raw_response_holder = kwargs.pop("_raw_response_holder")
    add_to_trace(CompletionRequestEvent(name=gen_id))

    log_llm_call_args = configs.getattrs("settings.logging.display.llm_raw_call_args")
    log_llm_response = configs.getattrs("settings.logging.display.llm_raw_response")
    log_llm_usage = configs.getattrs("settings.logging.display.llm_raw_usage")
    log_llm_cache = configs.getattrs("settings.logging.display.llm_cache")
    if log_llm_call_args:
        logger.info(f"Call completion [{gen_id}] with args: {kwargs}")

    @_langsmith_traceable(
        name=f"ChatCompletion_{gen_id}",
        run_type="llm",
        metadata={"appl": "completion", "appl_version": __version__},
    )  # type: ignore
    def wrapped(**inner_kwargs: Any) -> Tuple[Any, bool]:
        if cache_ret := find_in_cache(gen_id, inner_kwargs):
            if log_llm_cache:
                logger.info("Found in cache, using cached response...")
            # ? support rebuilding the stream from a cached response
            if inner_kwargs.get("stream", False):
                logger.warning(
                    "Using cached complete response for a streaming generation."
                )
            raw_response = cache_ret
        else:
            # if log_llm_cache:
            #     logger.info("Not found in cache, creating response...")
            raw_response = litellm.completion(**inner_kwargs)
        return raw_response, cache_ret is not None

    try:
        raw_response, use_cache = wrapped(**kwargs)
    except Exception as e:
        # log the error information for debugging
        logger.error(f"Error encountered for the completion: {e}")
        logger.info(f"kwargs:\n{kwargs}")
        raise e

    if raw_response_holder is not None:
        raw_response_holder.append(raw_response)

    def post_completion(response: CompletionResponse) -> None:
        raw_response = response.complete_response
        cost = 0.0 if use_cache else response.cost
        response.cost = cost  # update the cost
        event = CompletionResponseEvent(
            name=gen_id, args=kwargs, ret=raw_response, cost=cost
        )
        add_to_trace(event)
        if log_llm_response:
            logger.info(f"Completion [{gen_id}] response: {response}")
        if log_llm_usage and response.usage is not None:
            logger.info(f"Completion [{gen_id}] usage: {response.usage}")

    return CompletionResponse(
        raw_response=raw_response, post_finish_callbacks=[post_completion]
    )  # type: ignore
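
# Illustrative usage sketch (assumed values, not part of the original module):
# gen_id is consumed for tracing/caching and everything else is forwarded to
# litellm.completion unchanged.
#
#     response = chat_completion(
#         gen_id="gen_0",  # hypothetical identifier used for tracing and caching
#         model="gpt-4o-mini",  # assumed model name
#         messages=[{"role": "user", "content": "Hello"}],
#     )
#     print(response.results)  # consumes the stream first if streaming was requested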


# TODO: add default batch_size, to avoid too many requests
class APIServer(BaseServer):
    """The server for API models. It is a wrapper of litellm.completion."""

    def __init__(
        self,
        model: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        cost_currency: str = "USD",
        **kwargs: Any,
    ) -> None:
        """Initialize the API server.

        See [LiteLLM](https://docs.litellm.ai/docs/providers)
        for available models and providers.
        See [completion](https://docs.litellm.ai/docs/completion/input#input-params-1)
        for available options.
        """
        super().__init__()
        self._model = model
        self._base_url = base_url
        self._api_key = api_key
        self._custom_llm_provider = custom_llm_provider
        if custom_llm_provider is not None and api_key is None:
            self._api_key = "NotRequired"  # bypass the api_key check of litellm
        self._cost_currency = cost_currency
        self._default_args = kwargs

    @property
    def model_name(self):
        """The model name."""
        return self._model

    def _get_create_args(self, args: GenArgs, **kwargs: Any) -> dict:
        # supports custom postprocessing of create_args
        create_args = self._default_args.copy()
        postprocess = kwargs.pop("postprocess_args", None)
        if self._base_url is not None:
            create_args["base_url"] = self._base_url
        if self._api_key is not None:
            create_args["api_key"] = self._api_key
        if self._custom_llm_provider:
            create_args["custom_llm_provider"] = self._custom_llm_provider
        create_args.update(kwargs)  # update create_args with other kwargs

        # add the generation args to create_args
        create_args.update(args.preprocess(self._convert, is_openai=True))
        if postprocess is not None:
            create_args = postprocess(create_args)

        return create_args
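
    # Merge-order sketch (assumed values for illustration only): constructor
    # defaults are applied first, then per-call kwargs, then the preprocessed
    # GenArgs, and finally an optional postprocess_args callable.
    #
    #     server = APIServer(model="gpt-4o-mini", temperature=0.5)  # assumed
    #     # _get_create_args(args, temperature=0.0) then yields a dict where
    #     # temperature == 0.0, unless args.preprocess(...) overrides it again.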

    def _create(self, **kwargs: Any) -> CompletionResponse:
        response: CompletionResponse = None  # type: ignore
        # stores the raw response when the patched completion raises an error
        kwargs["_raw_response_holder"] = []

        response_model = kwargs.get("response_model", None)
        response_format = kwargs.get("response_format", None)
        if response_model is not None:
            try:
                from instructor.mode import Mode
                from instructor.patch import patch
            except ImportError:
                raise RuntimeError(
                    "response_model requires instructor, install with `pip install instructor`"
                )

            def wrapper(**inner_kwargs: Any) -> CompletionResponse:
                nonlocal response
                response = chat_completion(**inner_kwargs)
                return response

            try:
                # Use instructor.patch to enable using a pydantic model as the response model
                # added arguments: response_model, validation_context, max_retries
                mode = kwargs.pop("instructor_patch_mode", Mode.JSON)
                patched = patch(create=wrapper, mode=mode)
                results = patched(**kwargs)
                # fill in the response_model and response_obj
                response.response_model = response_model
                response.set_response_obj(results)
                # TODO?: update the cost for multiple retries
                # instructor has updated the total usage for retries
                # ?? response.cost = completion_cost({"usage": response.usage})
            except Exception as e:
                # log the error information for debugging
                logger.error(f"Error encountered for the patched completion: {e}")
                _raw_response_holder = kwargs.pop("_raw_response_holder", [])
                logger.info(f"kwargs:\n{yaml.dump(kwargs)}")
                if _raw_response_holder:
                    logger.info(f"raw_response:\n{_raw_response_holder[0]}")
                raise e
        else:
            wrapped_attribute = kwargs.pop("_wrapped_attribute", None)
            response = chat_completion(**kwargs)
            if isinstance(response_format, type) and issubclass(
                response_format, BaseModel
            ):
                response.response_model = response_format
                assert (
                    response.response_obj is None
                ), "response_obj should not be set yet."
                # retrieve the response message and convert it to the response model;
                # fetching the results will consume the stream if it is a streaming response
                response_obj = response_format.model_validate_json(response.results)
                if wrapped_attribute:
                    assert hasattr(
                        response_obj, wrapped_attribute
                    ), f"should have attribute {wrapped_attribute} in the response"
                    response_obj = getattr(response_obj, wrapped_attribute)
                response.set_response_obj(response_obj)
        return response
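
    # Structured-output sketch (hypothetical schema and values, not part of the
    # original module): when response_format is a pydantic BaseModel subclass,
    # the reply's JSON content is validated into that model and attached via
    # set_response_obj.
    #
    #     class Answer(BaseModel):  # hypothetical schema
    #         answer: str
    #
    #     response = server._create(
    #         gen_id="gen_1",  # required by chat_completion for tracing
    #         model="gpt-4o-mini",  # assumed model name
    #         messages=[{"role": "user", "content": "Reply in JSON."}],
    #         response_format=Answer,
    #     )
    #     answer_obj = response.response_obj  # -> Answer instance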

    def _convert(self, conversation: Conversation) -> List[Dict[str, str]]:
        return conversation.as_list()

    def close(self):
        """Close the server."""
        pass
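

# Illustrative end-to-end sketch (assumed values, not part of the original
# module): extra keyword arguments given to the constructor are stored as
# default completion arguments and merged into every request.
#
#     server = APIServer(
#         model="gpt-4o-mini",  # assumed model name
#         temperature=0.0,      # stored in _default_args, see _get_create_args
#     )
#     assert server.model_name == "gpt-4o-mini"
#     server.close()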