Coverage for src/appl/core/server.py: 95% (60 statements)
from abc import ABC, abstractmethod
from typing import Any, Callable, Literal, Optional, Sequence, Type, Union

from loguru import logger
from pydantic import BaseModel, Field
from typing_extensions import override

from .message import Conversation
from .response import CompletionResponse
from .tool import BaseTool, ToolCall
from .types import MaybeOneOrMany


class GenArgs(BaseModel):
    """Common arguments for generating a response from a model."""

    model: str = Field(..., description="The name of the backend model")
    messages: Conversation = Field(
        ..., description="The conversation to use as a prompt"
    )
    max_tokens: Optional[int] = Field(
        None, description="The maximum number of tokens to generate"
    )
    stop: MaybeOneOrMany[str] = Field(None, description="The stop sequence(s)")
    temperature: Optional[float] = Field(
        None, description="The temperature for sampling"
    )
    top_p: Optional[float] = Field(None, description="The nucleus sampling parameter")
    n: Optional[int] = Field(None, description="The number of choices to generate")
    tools: Sequence[BaseTool] = Field([], description="The tools that can be used")
    tool_format: Literal["auto", "str"] = Field(
        "auto", description="The format for the tools"
    )
    stream: Optional[bool] = Field(None, description="Whether to stream the results")
    response_format: Optional[Union[dict, Type[BaseModel]]] = Field(
        None, description="OpenAI's argument that specifies the response format."
    )
    response_model: Optional[Type[BaseModel]] = Field(
        None,
        description="instructor's argument that specifies the response format as a Pydantic model.",
    )

    def preprocess(self, convert_func: Callable, is_openai: bool = False) -> dict:
        """Convert the GenArgs into a dictionary for creating the response."""
        # build the dict, filtering out None values
        args = self.model_dump(exclude_none=True)

        # convert the messages into the prompt format the model expects
        args["messages"] = convert_func(self.messages)

        # format the tools
        tools = self.tools
        tool_format = args.pop("tool_format")
        if len(tools):
            if tool_format == "auto":
                tool_format = "openai" if is_openai else "str"
            formatted_tools = []
            for tool in tools:
                tool_str: Any = None
                if tool_format == "openai":
                    tool_str = tool.openai_schema
                else:  # TODO: support more formats
                    tool_str = str(tool)
                formatted_tools.append(tool_str)
            args["tools"] = formatted_tools
        else:
            args.pop("tools", None)
        return args
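

# Hedged usage sketch (not part of the original module): illustrates how
# GenArgs.preprocess is expected to behave. The model name and the identity
# convert_func below are hypothetical stand-ins; real servers pass their own
# _convert method, and Conversation() is assumed to be constructible with
# its defaults.
def _example_preprocess_usage() -> dict:
    args = GenArgs(
        model="example-model",  # hypothetical backend name
        messages=Conversation(),
        temperature=0.0,
    )
    # With no tools given, the "tools" key is dropped and None-valued fields
    # are excluded, so the result holds only model, messages, and temperature.
    return args.preprocess(convert_func=lambda conv: conv, is_openai=True)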


class BaseServer(ABC):
    """The base class for all servers.

    Servers are responsible for communicating with the underlying model.
    """

    @property
    @abstractmethod
    def model_name(self) -> str:
        """The name of the model used by the server."""
        raise NotImplementedError

    @abstractmethod
    def _get_create_args(self, args: GenArgs, **kwargs: Any) -> dict:
        """Map the GenArgs to the arguments for creating the response."""
        raise NotImplementedError

    @abstractmethod
    def _convert(self, conversation: Conversation) -> Any:
        """Convert the conversation into the prompt format for the model.

        Args:
            conversation (Conversation): The conversation to convert.

        Returns:
            The prompt for the model in the format it expects.
        """
        raise NotImplementedError

    @abstractmethod
    def _create(self, **kwargs: Any) -> CompletionResponse:
        """Create a CompletionResponse from the model with processed arguments.

        Args:
            kwargs: The arguments to pass to the model.

        Returns:
            CompletionResponse: The response from the model.
        """
        raise NotImplementedError

    def create(self, args: GenArgs, gen_id: str, **kwargs: Any) -> CompletionResponse:
        """Create a CompletionResponse from the model with the given arguments.

        Args:
            args: The arguments for generating the response.
            gen_id: The ID of the generation.
            **kwargs: Additional keyword arguments.

        Returns:
            The response from the model.
        """
        create_args = self._get_create_args(args, **kwargs)
        results = self._create(gen_id=gen_id, **create_args)
        return results

    @abstractmethod
    def close(self):
        """Close the server."""
        raise NotImplementedError


class DummyServer(BaseServer):
    """A dummy server for testing purposes."""

    @override
    @property
    def model_name(self) -> str:
        return "_dummy"

    def _get_create_args(self, args: GenArgs, **kwargs: Any) -> dict:  # type: ignore
        return kwargs

    def _convert(self, conversation: Conversation) -> Any:
        return conversation

    def _create(self, **kwargs) -> CompletionResponse:  # type: ignore
        message = kwargs.get("mock_response", "This is a dummy response")
        return CompletionResponse(message=message)  # type: ignore

    @override
    def close(self):
        pass
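

# Hedged usage sketch (not part of the original module): exercises the
# DummyServer end to end. The gen_id and mock_response values are made up
# for illustration, and Conversation() is assumed to be constructible with
# its defaults.
if __name__ == "__main__":
    server = DummyServer()
    gen_args = GenArgs(model=server.model_name, messages=Conversation())
    # DummyServer._get_create_args forwards only the extra kwargs, so
    # mock_response ends up as the message of the CompletionResponse.
    response = server.create(gen_args, gen_id="demo", mock_response="hello")
    logger.info(response)
    server.close()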