Coverage for src/appl/core/server.py: 95%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1from abc import ABC, abstractmethod 

2from typing import Any, Callable, Literal, Optional, Sequence, Type, Union 

3 

4from loguru import logger 

5from pydantic import BaseModel, Field 

6from typing_extensions import override 

7 

8from .message import Conversation 

9from .response import CompletionResponse 

10from .tool import BaseTool, ToolCall 

11from .types import MaybeOneOrMany 

12 

13 

class GenArgs(BaseModel):
    """Common arguments for generating a response from a model."""

    model: str = Field(..., description="The name of the backend model")
    messages: Conversation = Field(
        ..., description="The conversation to use as a prompt"
    )
    max_tokens: Optional[int] = Field(
        None, description="The maximum number of tokens to generate"
    )
    stop: MaybeOneOrMany[str] = Field(None, description="The stop sequence(s)")
    temperature: Optional[float] = Field(
        None, description="The temperature for sampling"
    )
    top_p: Optional[float] = Field(None, description="The nucleus sampling parameter")
    n: Optional[int] = Field(None, description="The number of choices to generate")
    # default_factory is the idiomatic way to declare a mutable default in
    # pydantic; it avoids sharing a single list object as the field default.
    tools: Sequence[BaseTool] = Field(
        default_factory=list, description="The tools can be used"
    )
    tool_format: Literal["auto", "str"] = Field(
        "auto", description="The format for the tools"
    )
    stream: Optional[bool] = Field(None, description="Whether to stream the results")
    response_format: Optional[Union[dict, Type[BaseModel]]] = Field(
        None, description="OpenAI's argument specifies the response format."
    )
    response_model: Optional[Type[BaseModel]] = Field(
        None,
        description="instructor's argument specifies the response format as a Pydantic model.",
    )

    def preprocess(self, convert_func: Callable, is_openai: bool = False) -> dict:
        """Convert the GenArgs into a dictionary for creating the response.

        Args:
            convert_func: Function that converts the Conversation into the
                message format the target backend expects.
            is_openai: Whether the backend is OpenAI-compatible; controls
                how ``tool_format="auto"`` is resolved.

        Returns:
            dict: Keyword arguments for the backend call, with fields left
            at ``None`` removed.
        """
        # build dict, filter out the None values
        args = self.model_dump(exclude_none=True)

        # messages: replace the dumped Conversation with the backend format
        args["messages"] = convert_func(self.messages)

        # format the tool; "auto" resolves to OpenAI's JSON schema for
        # OpenAI-compatible backends and a plain string rendering otherwise
        tools = self.tools
        tool_format = args.pop("tool_format")
        if len(tools):
            if tool_format == "auto":
                tool_format = "openai" if is_openai else "str"
            formatted_tools = []
            for tool in tools:
                tool_str: Any = None
                if tool_format == "openai":
                    tool_str = tool.openai_schema
                else:  # TODO: supports more formats
                    tool_str = str(tool)
                formatted_tools.append(tool_str)
            args["tools"] = formatted_tools
        else:
            # no tools configured: drop the (empty) "tools" entry entirely
            args.pop("tools", None)
        return args

69 

70 

class BaseServer(ABC):
    """Abstract interface for model servers.

    A server owns all communication with an underlying model backend:
    rendering conversations into the backend's prompt format, mapping
    generation arguments, and producing completion responses.
    """

    @property
    @abstractmethod
    def model_name(self) -> str:
        """The name of the model used by the server."""
        raise NotImplementedError

    @abstractmethod
    def _get_create_args(self, args: GenArgs, **kwargs: Any) -> dict:
        """Translate the GenArgs into backend-specific creation arguments."""
        raise NotImplementedError

    @abstractmethod
    def _convert(self, conversation: Conversation) -> Any:
        """Render a conversation in the prompt format the model expects.

        Args:
            conversation (Conversation): The conversation to convert

        Returns:
            The prompt for the model in the format it expects
        """
        raise NotImplementedError

    @abstractmethod
    def _create(self, **kwargs: Any) -> CompletionResponse:
        """Produce a CompletionResponse from fully-processed arguments.

        Args:
            kwargs: The arguments to pass to the model.

        Returns:
            CompletionResponse: The response from the model.
        """
        raise NotImplementedError

    def create(self, args: GenArgs, gen_id: str, **kwargs: Any) -> CompletionResponse:
        """Generate a completion response for the given arguments.

        Args:
            args: The arguments for generating the response
            gen_id: The ID of the generation
            **kwargs: Additional keyword arguments

        Returns:
            The response from the model.
        """
        # Map the high-level GenArgs onto backend kwargs, then delegate.
        processed = self._get_create_args(args, **kwargs)
        return self._create(gen_id=gen_id, **processed)

    @abstractmethod
    def close(self):
        """Close the server."""
        raise NotImplementedError

130 

131 

class DummyServer(BaseServer):
    """A dummy server for testing purposes."""

    # PEP 698: @override must be the decorator closest to the function
    # (i.e. below @property) so type checkers apply it to the getter.
    @property
    @override
    def model_name(self) -> str:
        """The name of the model used by the server."""
        return "_dummy"

    @override
    def _get_create_args(self, args: GenArgs, **kwargs: Any) -> dict:  # type: ignore
        # The dummy server ignores GenArgs and forwards kwargs unchanged.
        return kwargs

    @override
    def _convert(self, conversation: Conversation) -> Any:
        # No prompt formatting needed; pass the conversation through as-is.
        return conversation

    @override
    def _create(self, **kwargs) -> CompletionResponse:  # type: ignore
        """Return a canned response, honoring an optional ``mock_response``."""
        message = kwargs.get("mock_response", "This is a dummy response")
        return CompletionResponse(message=message)  # type: ignore

    @override
    def close(self):
        """Close the server (a no-op for the dummy)."""
        pass