Coverage for tests/test_tool.py: 92%

87 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1import time 

2from typing import Any 

3 

4from pydantic import BaseModel, Field, create_model 

5 

6import appl 

7from appl import Generation, as_tool, gen, ppl 

8from appl.core import CompletionResponse, ToolCall 

9 

10 

def removed_keyword(d: Any, key: str) -> Any:
    """Return a copy of ``d`` with every dict entry named ``key`` stripped.

    Dicts and lists are walked recursively and rebuilt; any other value
    (including tuples and scalars) is handed back unchanged.

    Args:
        d: Arbitrarily nested structure of dicts, lists and leaf values.
        key: Dict key to drop wherever it appears.

    Returns:
        The rebuilt structure without any ``key`` entries.
    """
    if isinstance(d, dict):
        return {
            name: removed_keyword(value, key)
            for name, value in d.items()
            if name != key
        }
    if isinstance(d, list):
        return [removed_keyword(item, key) for item in d]
    # Leaf value: nothing to strip.
    return d

17 

18 

def get_openai_schema(include_desc=True, add_default=False):
    """Build the OpenAI function schema the tests expect for the ``add`` tool.

    Args:
        include_desc: When truthy, attach the "first number" / "second number"
            descriptions to the ``x`` and ``y`` properties.
        add_default: When truthy, give ``y`` a default of 1 and drop it from
            the list of required parameters.

    Returns:
        A dict with ``"type": "function"`` and a ``"function"`` entry holding
        the name, description and parameter schema.
    """
    x_prop = {"type": "integer"}
    y_prop = {"type": "integer"}
    required = ["x", "y"]

    if include_desc:
        x_prop["description"] = "first number"
        y_prop["description"] = "second number"

    if add_default:
        # A parameter with a default is optional, so only ``x`` stays required.
        y_prop["default"] = 1
        required = ["x"]

    return {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Add two numbers together",
            "parameters": {
                "type": "object",
                "properties": {"x": x_prop, "y": y_prop},
                "required": required,
            },
        },
    }

43 

44 

def test_as_tool():
    """Check the schemas ``as_tool`` produces for three flavours of ``add``."""

    # Case 1: annotated signature plus a docstring with Args/Returns/Raises.
    def add(x: int, y: int) -> int:
        """Add two numbers together

        Args:
            x (int): first number
            y (int): second number
        Returns:
            int: sum of x and y
        Raises:
            ValueError: if x or y is not an integer
        """
        if not (isinstance(x, int) and isinstance(y, int)):
            raise ValueError("x and y must be integers")
        return x + y

    documented_tool = as_tool(add)
    # The wrapper must stay callable and keep the original docstring.
    assert documented_tool(1, 2) == 3
    assert documented_tool.__doc__ == add.__doc__
    assert (
        removed_keyword(documented_tool.openai_schema, "title")
        == get_openai_schema()
    )

    expected_returns = {
        "type": "object",
        "properties": {
            "returns": {"type": "integer", "description": "sum of x and y"},
        },
        "required": ["returns"],
    }
    returns_schema = documented_tool.returns.model_json_schema()
    assert removed_keyword(returns_schema, "title") == expected_returns

    expected_raises = [
        {"type": "ValueError", "desc": "if x or y is not an integer"},
    ]
    assert documented_tool.raises == expected_raises

    # Case 2: no annotations and a default for ``y``; the expected schema
    # still reports integer parameters, with ``y`` optional.
    def add(x, y=1):
        """Add two numbers together

        Args:
            x (int): first number
            y (int): second number
        """

    defaulted_tool = as_tool(add)
    defaulted_schema = removed_keyword(defaulted_tool.openai_schema, "title")
    assert defaulted_schema == get_openai_schema(add_default=True)

    # Case 3: annotations only; no parameter descriptions are expected.
    def add(x: int, y: int):
        """Add two numbers together"""

    undescribed_tool = as_tool(add)
    undescribed_schema = removed_keyword(undescribed_tool.openai_schema, "title")
    assert undescribed_schema == get_openai_schema(include_desc=False)

96 

97 

class AddArgs(BaseModel):
    # Argument bundle for the ``add`` tool used in test_args_schema.
    # Documented with comments on purpose: a class docstring would become the
    # model's schema description and alter the JSON schema under test.

    # Required operand.
    x: int = Field(default=..., description="first number")
    # Optional operand; falls back to 1 when omitted.
    y: int = Field(default=1, description="second number")

101 

102 

def test_args_schema():
    """Check ``as_tool`` parameter models when the argument is a pydantic model."""

    # Case 1: a single required ``AddArgs`` parameter.
    def add(args: AddArgs):
        """Add two numbers together"""
        return args.x + args.y

    required_tool = as_tool(add)
    required_model = create_model("parameters", args=(AddArgs, ...))
    actual_schema = required_tool.params.model_json_schema()
    assert actual_schema == required_model.model_json_schema()

    # Case 2: the parameter has a default instance and a docstring description.
    def add(args: AddArgs = AddArgs(x=1, y=2)):
        """Add two numbers together

        Args:
            args (AddArgs): arguments
        """
        return args.x + args.y

    defaulted_tool = as_tool(add)
    defaulted_field = Field(AddArgs(x=1, y=2), description="arguments")
    defaulted_model = create_model("parameters", args=(AddArgs, defaulted_field))
    actual_schema = defaulted_tool.params.model_json_schema()
    assert actual_schema == defaulted_model.model_json_schema()

131 

132 

def test_tool_call_sequential():
    """Run filtered tool calls one after another and check results and order."""
    appl.init()

    # Four mocked calls; the "mul" call has no matching tool and is filtered out.
    mocked_calls = [
        ToolCall(id="1", name="add", args='{"x": 1, "y": 2, "t": 0.1}'),
        ToolCall(id="2", name="mul", args='{"x": 2, "y": 3}'),
        ToolCall(id="3", name="add", args='{"x": 3, "y": 4}'),
        ToolCall(id="4", name="add", args='{"x": 5, "y": 6}'),
    ]
    response = CompletionResponse(tool_calls=mocked_calls)

    # Records the order in which ``add`` actually finishes its work.
    call_args = []

    def add(x: int, y: int, t: float = 0.0) -> int:
        # The first call sleeps, so any out-of-order execution would show up
        # in ``call_args``.
        time.sleep(t)
        call_args.append((x, y))
        return x + y

    tools = [as_tool(add)]

    def filter_fn(tool_calls: list[ToolCall]) -> list[ToolCall]:
        kept = []
        for call in tool_calls:
            if call.name == "add":
                kept.append(call)
        return kept

    @ppl
    def func():
        generation = gen("_dummy", tools=tools, mock_response=response)
        messages = generation.run_tool_calls(filter_fn=filter_fn)
        return [message.get_content() for message in messages]

    results = func()
    assert results == [3, 7, 11]
    assert call_args == [(1, 2), (3, 4), (5, 6)]

163 

164 

def test_tool_call_parallel():
    """Run two slow tool calls with ``parallel="thread"`` and check they overlap.

    Each mocked ``add`` call sleeps for ``t`` seconds. Run back to back, the
    two calls would need at least ``2 * t``, so finishing sooner than that
    shows the calls overlapped.
    """
    appl.init()
    response = CompletionResponse(
        tool_calls=[
            ToolCall(id="1", name="add", args='{"x": 1, "y": 2}'),
            ToolCall(id="3", name="add", args='{"x": 3, "y": 4}'),
        ],
    )

    t = 0.2

    def add(x: int, y: int) -> int:
        time.sleep(t)
        return x + y

    tools = [as_tool(add)]

    @ppl
    def func():
        res = gen(tools=tools, mock_response=response)
        return [x.get_content() for x in res.run_tool_calls(parallel="thread")]

    # perf_counter() is monotonic, unlike time.time(), so a system clock
    # adjustment during the run cannot distort the measured duration.
    start_time = time.perf_counter()
    results = func()
    elapsed = time.perf_counter() - start_time

    assert results == [3, 7]
    # Compare against the sequential lower bound rather than a fixed slack, so
    # scheduling overhead on a busy machine does not fail the test.
    sequential_lower_bound = 2 * t
    assert elapsed < sequential_lower_bound