Coverage for tests/test_message.py: 55%

138 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1import pytest 

2 

3import appl 

4from appl import AIRole, SystemRole, UserRole, ppl, records 

5from appl.core import ( 

6 AIMessage, 

7 Conversation, 

8 PromptRecords, 

9 SystemMessage, 

10 ToolCall, 

11 UserMessage, 

12) 

13 

14 

def test_message():
    """Each message class reports its own role and serializes to a role/content dict."""
    system_msg = SystemMessage("You are a helpful assistant.")
    user_msg = UserMessage("Hello, who are you")
    ai_msg = AIMessage("I am a helpful assistant.")

    # Role predicates, one per message class.
    assert system_msg.is_system
    assert user_msg.is_user
    assert ai_msg.is_ai

    # Serialized form: the role name plus the raw content string.
    expectations = [
        (system_msg, {"role": "system", "content": "You are a helpful assistant."}),
        (user_msg, {"role": "user", "content": "Hello, who are you"}),
        (ai_msg, {"role": "assistant", "content": "I am a helpful assistant."}),
    ]
    for message, expected_dict in expectations:
        assert message.get_dict() == expected_dict

31 

32 

def test_msg_role():
    """A bare string becomes a user turn; an ``AIMessage`` adds an assistant turn."""

    @ppl
    def func() -> PromptRecords:
        "Hello"
        AIMessage("World")
        return records()

    conversation = func().as_convo().as_list()
    assert conversation == [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "World"},
    ]

44 

45 

def test_chat_example():
    """Message objects emitted in a loop yield one system turn then alternating turns."""

    @ppl
    def func() -> PromptRecords:
        SystemMessage("You are a helpful assistant.")
        for i in range(5):
            UserMessage("Hello")
            AIMessage("World")
        return records()

    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    exchange = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "World"},
    ]
    # One system turn followed by five user/assistant exchanges.
    assert func().as_convo().as_list() == [system_turn] + exchange * 5

60 

61 

def test_system_msg():
    """Consecutive strings inside ``SystemRole`` form one system turn joined by a newline."""

    @ppl
    def func() -> PromptRecords:
        with SystemRole():
            "You are a helpful Assistant."
            "You should be helpful."
        with UserRole():
            "Hello"
        return records()

    system_content = "You are a helpful Assistant.\nYou should be helpful."
    expected = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": "Hello"},
    ]
    assert func().as_convo().as_list() == expected

79 

80 

def test_message_merge():
    """Adjacent message objects of the same role are concatenated with no separator.

    The trailing ``SystemMessage`` is expected to be folded into the leading
    system turn rather than appended as a new turn at the end.
    """

    @ppl
    def func() -> PromptRecords:
        SystemMessage("You are a helpful assistant.")
        UserMessage("Hello, who are you")
        AIMessage("I am ")
        AIMessage("a helpful ")
        AIMessage("assistant.")
        UserMessage("H")
        UserMessage("i")
        SystemMessage(" Add something to system prompt.")
        return records()

    expected = [
        {
            "role": "system",
            "content": "You are a helpful assistant. Add something to system prompt.",
        },
        {"role": "user", "content": "Hello, who are you"},
        {"role": "assistant", "content": "I am a helpful assistant."},
        {"role": "user", "content": "Hi"},
    ]
    merged = func().as_convo().as_list()
    assert merged == expected

103 

104 

def test_ai_tool_call():
    """An ``AIMessage`` with only tool calls serializes with ``content`` set to None.

    Each ``ToolCall`` is expected as a ``"function"``-typed entry whose ``args``
    string appears under ``function.arguments``.
    """

    @ppl
    def func() -> PromptRecords:
        AIMessage(tool_calls=[ToolCall(id="1", name="add", args='{"x": 1, "y": 2}')])
        return records()

    serialized_call = {
        "id": "1",
        "type": "function",
        "function": {"name": "add", "arguments": '{"x": 1, "y": 2}'},
    }
    assistant_turn = {
        "role": "assistant",
        "content": None,
        "tool_calls": [serialized_call],
    }
    assert func().as_convo().as_list() == [assistant_turn]

124 

125 

def test_role_wrap():
    """Explicit message objects keep their own role even inside an ``AIRole`` block."""

    @ppl
    def func() -> PromptRecords:
        with AIRole():
            UserMessage("Hello")
            AIMessage("World")
        return records()

    expected = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "World"},
    ]
    assert func().as_convo().as_list() == expected

138 

139 

def test_mix_role_usage():
    """A bare string inside ``AIRole`` is appended to the preceding assistant message."""

    @ppl
    def func() -> PromptRecords:
        with AIRole():
            UserMessage("Hello")
            AIMessage("World")
            "Again"
        return records()

    conversation = func().as_convo().as_list()
    assert conversation == [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "WorldAgain"},
    ]

153 

154 

def test_multi_user():
    """``UserRole(name)`` tags turns with a ``name``; different names give separate turns."""

    @ppl
    def func() -> PromptRecords:
        with UserRole("A"):
            "Hello"
        with UserRole("B"):
            "World"
        return records()

    expected = [
        {"role": "user", "name": speaker, "content": text}
        for speaker, text in (("A", "Hello"), ("B", "World"))
    ]
    assert func().as_convo().as_list() == expected

168 

169 

def test_role_change():
    """A bare string outside any role block is a user turn; inside ``AIRole`` it is assistant."""

    @ppl
    def func() -> PromptRecords:
        "Hello"
        with AIRole():
            "World"
        return records()

    user_turn = {"role": "user", "content": "Hello"}
    assistant_turn = {"role": "assistant", "content": "World"}
    assert func().as_convo().as_list() == [user_turn, assistant_turn]

182 

183 

def test_role_change_loop1():
    """Switching between the default role and ``AIRole`` in a loop yields alternating turns."""

    @ppl
    def func() -> PromptRecords:
        for i in range(5):
            "Hello"
            with AIRole():
                "World"
        return records()

    one_round = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "World"},
    ]
    assert func().as_convo().as_list() == one_round * 5

201 

202 

def test_role_change_loop2():
    """Mixing bare strings with ``AIMessage`` in a loop yields alternating turns."""

    @ppl
    def func() -> PromptRecords:
        for i in range(3):
            "Hello"
            AIMessage("World")
        return records()

    one_round = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "World"},
    ]
    assert func().as_convo().as_list() == one_round * 3

219 

220 

def test_role_change_loop3():
    """Alternating ``AIRole``/``UserRole`` blocks after an initial bare string stay separate turns."""

    @ppl
    def func() -> PromptRecords:
        "Hello"
        for i in range(5):
            with AIRole():
                "Foo"
            with UserRole():
                "Bar"
        return records()

    opening = [{"role": "user", "content": "Hello"}]
    cycle = [
        {"role": "assistant", "content": "Foo"},
        {"role": "user", "content": "Bar"},
    ]
    expected = opening + cycle * 5

    assert func().as_convo().as_list() == expected

238 

239 

def test_nested_role_error():
    """Opening ``AIRole`` while a ``UserRole`` block is still open raises ValueError.

    The expected error message contains "Cannot start a new role".
    """

    @ppl
    def func() -> PromptRecords:
        "Hello"
        for i in range(5):
            with UserRole():
                "Foo"
                with AIRole():
                    "Bar"
        return records()

    # Keep only the call inside ``pytest.raises`` so the test pins the error to
    # running the nested roles, not to decorating ``func``; ``match`` checks the
    # message in the same step (the pattern has no regex metacharacters).
    with pytest.raises(ValueError, match="Cannot start a new role"):
        func()