Coverage for tests/test_tool.py: 92%
87 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1import time
2from typing import Any
4from pydantic import BaseModel, Field, create_model
6import appl
7from appl import Generation, as_tool, gen, ppl
8from appl.core import CompletionResponse, ToolCall
def removed_keyword(d: Any, key: str) -> Any:
    """Recursively drop every dict entry named ``key`` from a nested structure.

    Dicts and lists are rebuilt (so ``d`` itself is not mutated), and their
    contents are processed recursively. Any other value, including tuples,
    is returned as-is.

    Args:
        d: the value to clean, typically a JSON-like schema.
        key: the dict key to remove at every nesting level.

    Returns:
        A cleaned copy of ``d`` (or ``d`` itself for non-container values).
    """
    if isinstance(d, dict):
        return {
            name: removed_keyword(value, key)
            for name, value in d.items()
            if name != key
        }
    if isinstance(d, list):
        return [removed_keyword(item, key) for item in d]
    return d
def get_openai_schema(include_desc=True, add_default=False):
    """Build the expected OpenAI function-tool schema for the test ``add`` tool.

    Args:
        include_desc: attach the "first number" / "second number" descriptions
            to ``x`` and ``y``.
        add_default: give ``y`` a default of 1 and drop it from ``required``.

    Returns:
        A fresh schema dict on every call, so callers may mutate it freely.
    """
    x_prop = {"type": "integer"}
    y_prop = {"type": "integer"}
    if include_desc:
        x_prop["description"] = "first number"
        y_prop["description"] = "second number"

    required = ["x"]
    if add_default:
        y_prop["default"] = 1
    else:
        required.append("y")

    return {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Add two numbers together",
            "parameters": {
                "type": "object",
                "properties": {"x": x_prop, "y": y_prop},
                "required": required,
            },
        },
    }
def test_as_tool():
    """``as_tool`` derives a callable tool plus its schemas from a plain function."""

    # Case 1: annotated function whose docstring has Args, Returns and Raises.
    def add(x: int, y: int) -> int:
        """Add two numbers together

        Args:
            x (int): first number
            y (int): second number
        Returns:
            int: sum of x and y
        Raises:
            ValueError: if x or y is not an integer
        """
        if not (isinstance(x, int) and isinstance(y, int)):
            raise ValueError("x and y must be integers")
        return x + y

    expected_returns_schema = {
        "type": "object",
        "properties": {
            "returns": {"type": "integer", "description": "sum of x and y"},
        },
        "required": ["returns"],
    }
    expected_raises = [
        {"type": "ValueError", "desc": "if x or y is not an integer"}
    ]

    documented = as_tool(add)
    assert documented(1, 2) == 3
    assert documented.__doc__ == add.__doc__
    assert removed_keyword(documented.openai_schema, "title") == get_openai_schema()
    returns_schema = documented.returns.model_json_schema()
    assert removed_keyword(returns_schema, "title") == expected_returns_schema
    assert documented.raises == expected_raises

    # Case 2: no annotations; the expected schema still types x and y as
    # integers (from the docstring) and `y` carries its default of 1.
    def add(x, y=1):
        """Add two numbers together

        Args:
            x (int): first number
            y (int): second number
        """

    defaulted = as_tool(add)
    expected_defaulted = get_openai_schema(add_default=True)
    assert removed_keyword(defaulted.openai_schema, "title") == expected_defaulted

    # Case 3: annotated but with no Args section, so no parameter descriptions.
    def add(x: int, y: int):
        """Add two numbers together"""

    undocumented = as_tool(add)
    expected_undocumented = get_openai_schema(include_desc=False)
    assert removed_keyword(undocumented.openai_schema, "title") == expected_undocumented
class AddArgs(BaseModel):
    # Structured argument model passed as the single `args` parameter in
    # test_args_schema. Documented with comments rather than a class docstring
    # because pydantic copies a model docstring into its JSON schema.
    x: int = Field(default=..., description="first number")  # required
    y: int = Field(default=1, description="second number")
def test_args_schema():
    """A pydantic-model parameter is kept as a nested field of the tool params."""

    # Required model argument with no docstring description.
    def add(args: AddArgs):
        """Add two numbers together"""
        return args.x + args.y

    tool = as_tool(add)
    reference = create_model("parameters", args=(AddArgs, ...))
    assert tool.params.model_json_schema() == reference.model_json_schema()

    # Defaulted model argument documented in an Args section: the schema must
    # match a reference field carrying that same default and description.
    def add(args: AddArgs = AddArgs(x=1, y=2)):
        """Add two numbers together

        Args:
            args (AddArgs): arguments
        """
        return args.x + args.y

    tool = as_tool(add)
    reference = create_model(
        "parameters",
        args=(AddArgs, Field(AddArgs(x=1, y=2), description="arguments")),
    )
    assert tool.params.model_json_schema() == reference.model_json_schema()
def test_tool_call_sequential():
    """Filtered tool calls run one after another, in the order they were issued."""
    appl.init()
    # (id, name, json-args) for each mocked call. The first "add" sleeps 0.1s,
    # so any reordering or overlap would show up in `call_args`.
    call_specs = [
        ("1", "add", '{"x": 1, "y": 2, "t": 0.1}'),
        ("2", "mul", '{"x": 2, "y": 3}'),
        ("3", "add", '{"x": 3, "y": 4}'),
        ("4", "add", '{"x": 5, "y": 6}'),
    ]
    response = CompletionResponse(
        tool_calls=[
            ToolCall(id=call_id, name=name, args=args)
            for call_id, name, args in call_specs
        ],
    )

    call_args = []  # records (x, y) in the order the tool actually ran

    def add(x: int, y: int, t: float = 0.0) -> int:
        time.sleep(t)
        call_args.append((x, y))
        return x + y

    tools = [as_tool(add)]

    def filter_fn(tool_calls: list[ToolCall]) -> list[ToolCall]:
        # Drop the "mul" call: only an "add" tool is registered above.
        return list(filter(lambda tc: tc.name == "add", tool_calls))

    # NOTE: the body of this @ppl function is left exactly as written, since
    # APPL re-parses the decorated function's source.
    @ppl
    def func():
        res = gen("_dummy", tools=tools, mock_response=response)
        return [x.get_content() for x in res.run_tool_calls(filter_fn=filter_fn)]

    assert func() == [3, 7, 11]
    assert call_args == [(1, 2), (3, 4), (5, 6)]
def test_tool_call_parallel():
    """Tool calls run with ``parallel="thread"`` overlap in time.

    Each of the two calls sleeps ``t`` seconds, so running them one after the
    other would take about ``2 * t``. Finishing within ``t`` plus a small
    margin shows they ran concurrently.
    """
    appl.init()
    response = CompletionResponse(
        tool_calls=[
            ToolCall(id="1", name="add", args='{"x": 1, "y": 2}'),
            ToolCall(id="3", name="add", args='{"x": 3, "y": 4}'),
        ],
    )

    t = 0.2
    margin = 0.1  # slack for thread start-up and scheduling overhead

    def add(x: int, y: int) -> int:
        time.sleep(t)
        return x + y

    tools = [as_tool(add)]

    @ppl
    def func():
        res = gen(tools=tools, mock_response=response)
        return [x.get_content() for x in res.run_tool_calls(parallel="thread")]

    # perf_counter is monotonic. time.time() follows the wall clock and can
    # jump if the system clock is adjusted mid-test, which would make this
    # timing assertion pass or fail spuriously.
    start = time.perf_counter()
    results = func()
    elapsed = time.perf_counter() - start

    assert results == [3, 7]
    assert elapsed < t + margin