Coverage for src/appl/core/tool.py: 91%

115 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1import inspect 

2import json 

3from abc import ABC, abstractmethod 

4from inspect import Signature 

5from typing import Any, Callable, Dict, List, Optional 

6 

7from docstring_parser import Docstring, parse 

8from loguru import logger 

9from openai.types.chat import ChatCompletionMessageToolCall 

10from pydantic import BaseModel, ConfigDict, Field, create_model, field_serializer 

11 

12from .types import Image, String 

13 

14# from pydantic._internal._model_construction import ModelMetaclass 

15 

16 

class BaseTool(BaseModel, ABC):
    """The base class for a Tool."""

    # A tool wraps a plain callable. Its name, descriptions, parameter model,
    # return model, raised exceptions and examples are all derived from the
    # callable's signature and docstring (see ``parse_data``).

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str = Field(..., description="The name of the Tool")
    """Name of the tool, taken from the wrapped function's ``__name__``."""
    short_desc: str = Field("", description="The short description of the Tool")
    """Short description parsed from the docstring (empty if absent)."""
    long_desc: str = Field("", description="The long description of the Tool")
    """Long description parsed from the docstring (empty if absent)."""
    params: type[BaseModel] = Field(..., description="The parameters of the Tool")
    """Dynamically created model describing the tool's parameters."""
    returns: type[BaseModel] = Field(..., description="The return of the Tool")
    """Dynamically created model describing the tool's return value."""
    raises: List[Dict[str, Optional[str]]] = Field(
        [], description="The exceptions raised by the Tool"
    )
    """Exceptions listed in the docstring, as ``{"type": ..., "desc": ...}``."""
    examples: List[str] = Field([], description="The examples of the Tool")
    """Examples section(s) parsed from the docstring."""
    info: Dict = Field({}, description="Additional information of the Tool")
    """Free-form additional information attached to the tool."""

    # TODO: add toolkit option
    def __init__(self, func: Callable, **predefined: Any):
        """Build the tool metadata from ``func``.

        Args:
            func: The callable to wrap.
            predefined:
                Keyword arguments fixed in advance. They are left out of the
                generated parameter model and injected on every call.
        """
        tool_name = func.__name__
        signature = inspect.signature(func)
        docstring = func.__doc__
        parsed_fields = self.parse_data(signature, docstring, predefined)
        super().__init__(name=tool_name, **parsed_fields)

        # Keep what is needed to invoke the wrapped callable later.
        self._func = func
        self._predefined = predefined

        # Mirror the wrapped callable's identity on the instance.
        self.__name__ = tool_name
        self.__signature__ = signature  # type: ignore
        self.__doc__ = docstring  # overwrite the doc string

    @classmethod
    def parse_data(
        cls, sig: Signature, docstring: Optional[str], predefined: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Derive the model fields from a function signature and docstring.

        Args:
            sig: Signature of the wrapped function.
            docstring: Raw docstring of the wrapped function, if any.
            predefined: Names of these keys are excluded from ``params``.

        Returns:
            A dict with the keys ``short_desc``, ``long_desc``, ``params``,
            ``returns``, ``raises`` and ``examples``.
        """
        parsed = parse(docstring or "")
        documented = {entry.arg_name: entry for entry in parsed.params}

        # Parameters: one (annotation, default) pair per non-predefined arg.
        fields: Dict[str, Any] = {}
        for arg_name, arg in sig.parameters.items():
            if arg_name in predefined:
                continue

            arg_type = arg.annotation
            # Ellipsis marks the field as required.
            arg_default = ... if arg.default is arg.empty else arg.default

            arg_doc = documented.get(arg_name)
            if arg_doc is not None:
                # Attach the docstring description to the field.
                arg_default = Field(arg_default, description=arg_doc.description)
                # An explicit annotation wins over the docstring type.
                if arg_type is arg.empty and arg_doc.type_name is not None:
                    arg_type = arg_doc.type_name

            if arg_type is arg.empty:
                arg_type = Any
            fields[arg_name] = (arg_type, arg_default)
        params_model = create_model("parameters", **fields)  # type: ignore

        # Return value: annotation first, then docstring type, then Any.
        return_doc = parsed.returns
        return_type = sig.return_annotation
        if return_type is sig.empty:
            return_type = Any
            if return_doc is not None and return_doc.type_name is not None:
                return_type = return_doc.type_name
        if return_doc is None:
            return_default = ...  # required, no description available
        else:
            return_default = Field(..., description=return_doc.description)
        returns_model = create_model("returns", returns=(return_type, return_default))

        # Documented exceptions.
        raised = [
            {"type": item.type_name, "desc": item.description}
            for item in parsed.raises
        ]

        return {
            "short_desc": parsed.short_description or "",
            "long_desc": parsed.long_description or "",
            "params": params_model,
            "returns": returns_model,
            "raises": raised,
            "examples": parsed.examples,
        }

    @property
    def openai_schema(self) -> dict:
        """Return the tool as an OpenAI ``function`` tool specification."""
        function_spec = {
            "name": self.name,
            "description": self._get_description(),
            "parameters": self.params.model_json_schema(),
        }
        return {"type": "function", "function": function_spec}

    def to_str(self) -> str:
        """Render the tool as a function header followed by its docstring."""
        header = f"def {self.name}{self.__signature__}:\n"
        # NOTE(review): the whitespace inside this literal is reproduced as
        # seen in the reviewed text; confirm the intended docstring indent.
        body = f' """{self.__doc__}"""'
        return header + body

    @abstractmethod
    def _get_description(self) -> str:
        """Return the description used in the OpenAI schema."""
        raise NotImplementedError

    @field_serializer("params", when_used="json")
    def _serialize_params(self, params: type[BaseModel]) -> dict:
        """Serialize the parameter model as its JSON schema."""
        return params.model_json_schema()

    @field_serializer("returns", when_used="json")
    def _serialize_returns(self, returns: type[BaseModel]) -> dict:
        """Serialize the return model as its JSON schema."""
        return returns.model_json_schema()

    def __str__(self) -> str:
        return self.to_str()

    def _call(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped callable; predefined kwargs take precedence."""
        merged = {**kwargs, **self._predefined}
        return self._func(*args, **merged)

    __call__ = _call

149 

150 

class Tool(BaseTool):
    """The Tool class that can be called by LLMs."""

    def __init__(self, func: Callable, use_short_desc: bool = False, **kwargs: Any):
        """Wrap a function as a tool.

        Args:
            func: The function the tool is built from.
            use_short_desc:
                If true, report only the short description even when a long
                description is available.
            kwargs: Passed through unchanged to the base class initializer.
        """
        super().__init__(func=func, **kwargs)
        self._use_short_desc = use_short_desc

    def _get_description(self):
        """Pick the description text reported for this tool.

        Falls back to the tool name (with a warning) when there is no short
        description; otherwise returns the short description, followed by the
        long one unless it is missing or short descriptions were requested.
        """
        if self.short_desc:
            short_only = self._use_short_desc or not self.long_desc
            if short_only:
                return self.short_desc
            # Full description: short and long separated by a blank line.
            return self.short_desc + "\n\n" + self.long_desc

        logger.warning(f"Tool {self.name} has no description.")
        return self.name

176 

177 

class ToolCall(BaseModel):
    """The class representing a tool call."""

    id: str = Field(..., description="The ID of the tool call.")
    """Identifier of the tool call."""
    name: str = Field(..., description="The name of the function to call.")
    """Name of the function being called."""
    args: str = Field(..., description="The arguments to call the function with.")
    """Arguments for the call, stored as a string."""

    def get_dict(self):
        """Return the tool call as an OpenAI-format dictionary."""
        function_part = {"name": self.name, "arguments": self.args}
        return {"id": self.id, "type": "function", "function": function_part}

    @classmethod
    def from_dict(cls, call: Dict) -> "ToolCall":
        """Build a ToolCall from an OpenAI-format dictionary.

        A missing key is not handled here, so the lookup error propagates to
        the caller.
        """
        # throw error if incorrect format
        call_id = call["id"]
        function_part = call["function"]
        return cls(
            id=call_id,
            name=function_part["name"],
            args=function_part["arguments"],
        )

    @classmethod
    def from_openai_tool_call(cls, call: ChatCompletionMessageToolCall) -> "ToolCall":
        """Build a ToolCall from an OpenAI tool call object."""
        function_part = call.function
        return cls(id=call.id, name=function_part.name, args=function_part.arguments)

216 )