Coverage for src/appl/core/tool.py: 91%
115 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1import inspect
2import json
3from abc import ABC, abstractmethod
4from inspect import Signature
5from typing import Any, Callable, Dict, List, Optional
7from docstring_parser import Docstring, parse
8from loguru import logger
9from openai.types.chat import ChatCompletionMessageToolCall
10from pydantic import BaseModel, ConfigDict, Field, create_model, field_serializer
12from .types import Image, String
14# from pydantic._internal._model_construction import ModelMetaclass
class BaseTool(BaseModel, ABC):
    """The base class for a Tool.

    A tool wraps a plain Python callable. From the callable's signature and
    docstring it derives pydantic models for the parameters and the return
    value, plus short/long descriptions, documented exceptions and examples.
    Subclasses decide how the description is rendered by implementing
    ``_get_description``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str = Field(..., description="The name of the Tool")
    """The name of the Tool."""
    short_desc: str = Field("", description="The short description of the Tool")
    """The short description of the Tool."""
    long_desc: str = Field("", description="The long description of the Tool")
    """The long description of the Tool."""
    params: type[BaseModel] = Field(..., description="The parameters of the Tool")
    """The parameters of the Tool."""
    returns: type[BaseModel] = Field(..., description="The return of the Tool")
    """The return of the Tool."""
    raises: List[Dict[str, Optional[str]]] = Field(
        [], description="The exceptions raised by the Tool"
    )
    """The exceptions raised by the Tool."""
    examples: List[str] = Field([], description="The examples of the Tool")
    """The examples of the Tool."""
    info: Dict = Field({}, description="Additional information of the Tool")
    """Additional information of the Tool."""

    # TODO: add toolkit option
    def __init__(self, func: Callable, **predefined: Any):
        """Create a tool from a function.

        Args:
            func: The function to wrap. Its ``__name__`` becomes the tool
                name; its signature and docstring are parsed into the tool's
                metadata.
            **predefined: Keyword arguments bound in advance. They are left
                out of the generated parameter model and are merged into the
                keyword arguments of every call.
        """
        name = func.__name__
        sig = inspect.signature(func)
        doc = func.__doc__
        super().__init__(name=name, **self.parse_data(sig, doc, predefined))
        self._predefined = predefined
        self._func = func
        # Mirror the wrapped function's identity so the tool looks like it.
        self.__name__ = name
        self.__signature__ = sig  # type: ignore
        self.__doc__ = doc  # overwrite the doc string

    @classmethod
    def parse_data(
        cls, sig: Signature, docstring: Optional[str], predefined: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Parse data from the signature and docstring of a function.

        Args:
            sig: The signature of the function.
            docstring: The raw docstring of the function, or ``None``.
            predefined: Pre-bound keyword arguments; parameters whose names
                appear here are excluded from the generated parameter model.

        Returns:
            A dict with the keys ``short_desc``, ``long_desc``, ``params``,
            ``returns``, ``raises`` and ``examples``, passed as keyword
            arguments to the model constructor.
        """
        doc = parse(docstring or "")
        data: Dict[str, Any] = {
            "short_desc": doc.short_description or "",
            "long_desc": doc.long_description or "",
        }

        # build params
        params = {}
        doc_param = {p.arg_name: p for p in doc.params}
        for name, param in sig.parameters.items():
            anno = param.annotation
            default = param.default

            if default is param.empty:
                default = ...  # required
            if name in doc_param:
                # fill in desc for the param
                default = Field(default, description=doc_param[name].description)
                # fill in type annotation if not annotated in the function
                if (anno is param.empty) and (doc_param[name].type_name is not None):
                    # use type annotation from docstring
                    anno = doc_param[name].type_name
            # replace empty annotation with Any
            if anno is param.empty:
                anno = Any
            if name not in predefined:
                params[name] = (anno, default)
        data["params"] = create_model("parameters", **params)  # type: ignore

        # build returns
        anno = sig.return_annotation
        if anno is sig.empty:
            if (doc.returns is not None) and (doc.returns.type_name is not None):
                # use type annotation from docstring
                anno = doc.returns.type_name
            else:
                anno = Any
        default = ...  # required
        if doc.returns is not None:
            # fill in desc for the return
            default = Field(..., description=doc.returns.description)
        data["returns"] = create_model("returns", returns=(anno, default))

        # build raises
        data["raises"] = [
            {"type": exc.type_name, "desc": exc.description} for exc in doc.raises
        ]

        # build examples: the parser yields example objects (with ``snippet``
        # and ``description``), while the ``examples`` field holds strings.
        data["examples"] = [cls._example_to_str(ex) for ex in doc.examples]
        return data

    @staticmethod
    def _example_to_str(example: Any) -> str:
        """Flatten one parsed docstring example into plain text."""
        if isinstance(example, str):
            return example
        parts = (
            getattr(example, "snippet", None),
            getattr(example, "description", None),
        )
        return "\n".join(part for part in parts if part)

    @property
    def openai_schema(self) -> dict:
        """Get the OpenAI schema of the tool.

        Returns:
            A ``{"type": "function", "function": {...}}`` dict whose
            ``parameters`` entry is the JSON schema of ``self.params``.
        """
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self._get_description(),
                "parameters": self.params.model_json_schema(),
            },
        }

    def to_str(self) -> str:
        """Represent the tool as a Python-like function stub.

        The stub is ``def <name><signature>:`` followed by the wrapped
        function's docstring; a missing docstring renders as empty rather
        than as the text ``None``.
        """
        doc = self.__doc__ or ""
        s = f"def {self.name}{self.__signature__}:\n"
        s += f'    """{doc}"""'
        return s

    @abstractmethod
    def _get_description(self) -> str:
        """Return the description exposed to the LLM (subclass hook)."""
        raise NotImplementedError

    @field_serializer("params", when_used="json")
    def _serialize_params(self, params: type[BaseModel]) -> dict:
        # Model classes are not JSON-serializable; emit their JSON schema.
        return params.model_json_schema()

    @field_serializer("returns", when_used="json")
    def _serialize_returns(self, returns: type[BaseModel]) -> dict:
        # Model classes are not JSON-serializable; emit their JSON schema.
        return returns.model_json_schema()

    def __str__(self) -> str:
        return self.to_str()

    def _call(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped function.

        Predefined keyword arguments are merged last, so they override any
        call-time keyword argument with the same name.
        """
        kwargs.update(self._predefined)  # use predefined kwargs
        return self._func(*args, **kwargs)

    __call__ = _call
class Tool(BaseTool):
    """A concrete tool that LLMs can call.

    The description sent to the model comes from the wrapped function's
    docstring. It can be limited to the one-line summary with
    ``use_short_desc``.
    """

    def __init__(self, func: Callable, use_short_desc: bool = False, **kwargs: Any):
        """Create a tool from a function.

        Args:
            func: The function to create the tool from.
            use_short_desc:
                If true, describe the tool with its short description only,
                even when a long description is available.
            kwargs: Extra keyword arguments forwarded to the base class
                (pre-bound call arguments).
        """
        super().__init__(func=func, **kwargs)
        self._use_short_desc = use_short_desc

    def _get_description(self) -> str:
        summary, details = self.short_desc, self.long_desc
        if summary:
            # Only append the long description when it exists and the
            # caller did not ask for the short form.
            if details and not self._use_short_desc:
                return f"{summary}\n\n{details}"
            return summary
        # No docstring summary: fall back to the tool's name.
        logger.warning(f"Tool {self.name} has no description.")
        return self.name
class ToolCall(BaseModel):
    """A single tool call requested by an LLM."""

    id: str = Field(..., description="The ID of the tool call.")
    """The ID of the tool call."""
    name: str = Field(..., description="The name of the function to call.")
    """The name of the function to call."""
    args: str = Field(..., description="The arguments to call the function with.")
    """The arguments to call the function with."""

    def get_dict(self) -> Dict:
        """Return this call as an OpenAI-format tool-call dictionary.

        The shape is
        ``{"id": ..., "type": "function", "function": {"name": ..., "arguments": ...}}``.
        """
        function_part = {"name": self.name, "arguments": self.args}
        return {"id": self.id, "type": "function", "function": function_part}

    @classmethod
    def from_dict(cls, call: Dict) -> "ToolCall":
        """Build a ToolCall from an OpenAI-format dictionary.

        Raises:
            KeyError: If ``id``, ``function``, ``function.name`` or
                ``function.arguments`` is missing.
        """
        call_id = call["id"]
        function_part = call["function"]
        return cls(
            id=call_id,
            name=function_part["name"],
            args=function_part["arguments"],
        )

    @classmethod
    def from_openai_tool_call(cls, call: ChatCompletionMessageToolCall) -> "ToolCall":
        """Build a ToolCall from an OpenAI SDK tool-call object."""
        call_id = call.id
        function_part = call.function
        return cls(id=call_id, name=function_part.name, args=function_part.arguments)