Coverage for tests/test_prompt.py: 67%

224 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1import appl 

2import pytest 

3from appl import Generation, convo, define, gen, need_ctx, ppl, records 

4from appl.compositor import * 

5 

6 

def test_return():
    """An explicit return value from a @ppl function is handed back as-is."""

    @ppl
    def func():
        "Hello World"
        return "answer"

    result = func()
    assert result == "answer"

14 

15 

def test_prompt():
    """A string expression statement in a @ppl function is captured by records()."""

    # NOTE(review): func declares `_ctx` but is called with no arguments;
    # presumably @ppl supplies it — confirm against the decorator.
    @ppl
    def func(_ctx):
        "Hello World"
        return records()

    recorded = func()
    assert str(recorded) == "Hello World"

23 

24 

def test_fstring():
    """f-string prompts are recorded in formatted form (conversion / format spec applied)."""

    @ppl
    def f1():
        f"a is {1!r}"
        return records()

    @ppl
    def f2():
        f"a is {3.1415:.2f}"
        return records()

    # Each prompt function is paired with the plain-Python rendering of its f-string.
    cases = [
        (f1, f"a is {1!r}"),
        (f2, f"a is {3.1415:.2f}"),
    ]
    for prompt_fn, expected in cases:
        assert str(prompt_fn()) == expected

39 

40 

def test_prompts_change():
    """records() is a live reference, while records().copy() snapshots the prompt so far."""

    @ppl
    def func():
        "Hello"
        ret1 = records()  # the reference
        ret2 = records().copy()  # the copy of the current prompt
        "World"
        ret3 = records()
        return ret1, ret2, ret3

    reference, snapshot, latest = func()
    rendered = tuple(map(str, (reference, snapshot, latest)))
    # The reference sees "World" added later; the snapshot does not.
    assert rendered == ("Hello\nWorld", "Hello", "Hello\nWorld")

55 

56 

def test_return_prompt():
    """default_return="prompt" yields the prompt only when the function returns nothing itself."""

    @ppl(default_return="prompt")
    def f1():
        "Hello World"

    @ppl(default_return="prompt")
    def f2():
        "Hello World"
        return "answer"

    # No return statement: the recorded prompt comes back.
    assert str(f1()) == "Hello World"
    # An explicit return value is passed through unchanged.
    assert str(f2()) == "answer"

71 

72 

def test_record():
    """A nested @ppl call adds its prompts to the caller in the caller's active format.

    ``f2`` records two plain lines; calling it inside ``func``'s ``NumberedList``
    block numbers them after the caller's own lines.
    """

    @ppl
    def f2():
        "Hello"
        "World"
        return records()

    @ppl
    def func():
        with NumberedList():
            "first line"
            "second line"
            f2()  # add the prompts from f2, following the current format.
        return records()

    # Plain literal: the expected text has no interpolation (was a placeholder-less f-string).
    assert str(func()) == "1. first line\n2. second line\n3. Hello\n4. World"

89 

90 

def test_inner_func():
    """An undecorated helper defined and called inside a @ppl function records into it."""

    @ppl
    def func():
        "Hello"

        def func2():  # the inner function uses the same context as the outer function.
            "World"

        func2()
        return records()

    lines = str(func()).split("\n")
    assert lines == ["Hello", "World"]

103 

104 

def test_tripple_quote():
    """Check how indentation inside triple-quoted prompt strings is normalized.

    In each function the first triple-quoted string is the docstring and is
    not recorded; the second one is the prompt under test.
    """
    # NOTE(review): every expected string below depends on the exact
    # whitespace inside the triple-quoted literals; edit literal and
    # expectation together.

    @ppl
    def func1():
        """This is a docstring"""
        """
        begin
         1. first
         2. second
        """
        return records()

    @ppl
    def func2():
        """This is a docstring"""
        # Note that dedent will remove the leading spaces.
        """begin
        1. first
        2. second
        """
        return records()

    @ppl
    def func3():
        """This is a docstring"""
        # Not recommended
        """
        begin
         1. first
         2. second
         """  # note the leading spaces here
        return records()

    @ppl
    def func4():
        """This is a docstring"""
        # Not recommended, but still works.
        """begin
1. first
2. second
"""
        return records()

    assert str(func1()) == "begin\n 1. first\n 2. second"
    assert str(func2()) == "begin\n1. first\n2. second"
    # The closing line's extra indentation survives as trailing whitespace.
    assert str(func3()) == "begin\n 1. first\n 2. second\n "
    assert str(func4()) == "begin\n1. first\n2. second"

151 

152 

def test_tripple_quote_fstring():
    """Check indentation handling for triple-quoted f-string prompts.

    Covers a simple interpolation, interpolating a multi-line f-string value,
    and a replacement field that spans several source lines.
    """
    # NOTE(review): expectations depend on the exact whitespace inside the
    # triple-quoted literals; edit literal and expectation together.

    @ppl
    def func1():
        x = "end"
        f"""
        begin
        {x}
        """
        return records()

    @ppl
    def func2():
        x = f"""
        begin
         Hello

         World
        """
        f"""
        {x}
        end
        """
        return records()

    # A backslash inside an f-string replacement field is only legal from
    # Python 3.12 on (PEP 701).
    @ppl
    def func3():
        f"""
        1+1={1 + \
        1
        }
        """
        # The recovered code from libcst will become:
        # f"""
        # 1+1={1 + \
        # 1
        # }
        # """
        return records()

    assert str(func1()) == "begin\nend"
    assert str(func2()) == "begin\n Hello\n\n World\nend"
    assert str(func3()) == "1+1=2"

195 

196 

def test_include_docstring():
    """With include_docstring=True the docstring is recorded before later prompts."""

    @ppl(include_docstring=True)
    def func():
        """This is a docstring"""
        "Hello"
        return records()

    prompt = str(func())
    assert prompt.split("\n") == ["This is a docstring", "Hello"]

205 

206 

def test_include_multiline_docstring():
    """With include_docstring=True, multi-line docstrings are recorded with indentation normalized."""
    # NOTE(review): the expected strings depend on the exact whitespace inside
    # the docstring literals; edit literal and expectation together.

    @ppl(include_docstring=True)
    def func():
        """This is a
        multiline docstring"""

        "Hello"
        return records()

    assert str(func()) == "This is a\nmultiline docstring\nHello"

    @ppl(include_docstring=True)
    def func2():
        """
        This is a
         multiline docstring
        """
        return records()

    # Relative indentation inside the docstring is kept; the common margin is not.
    assert str(func2()) == "This is a\n multiline docstring"

227 

228 

def test_default_no_docstring():
    """By default the docstring is dropped, but an identical later string is recorded."""

    @ppl()
    def func():
        """This is a docstring"""
        "Hello"
        return records()

    @ppl()
    def func2():
        """Same string"""  # the first is docstring
        # the second string is not a docstring anymore, should be included
        """Same string"""
        return records()

    expectations = [(func, "Hello"), (func2, "Same string")]
    for prompt_fn, expected in expectations:
        assert str(prompt_fn()) == expected

245 

246 

def test_copy_ctx():
    """ctx="copy" sees the caller's prompts but does not write back into them.

    Each ``addon`` call sees only the caller's "Hello" (not a previous call's
    "World"), and the caller's own records stay "Hello".
    """

    @ppl(ctx="copy")
    def addon():
        "World"
        return str(convo())

    @ppl
    def func2():
        "Hello"
        first = addon()
        second = addon()
        return first, second, records()

    *copies, origin = func2()
    assert copies == ["Hello\nWorld", "Hello\nWorld"]
    assert str(origin) == "Hello"

264 

265 

def test_resume_ctx():
    """ctx="resume" keeps the function's context across calls, so prompts accumulate.

    After the k-th call the conversation holds k copies of "Hello".
    """

    @ppl(ctx="resume")
    def resume_ctx():
        "Hello"
        return convo()

    target = []
    # The loop index is unused, hence `_` (ruff B007).
    for _ in range(3):
        res = resume_ctx()
        target.append("Hello")
        assert str(res) == "\n".join(target)

277 

278 

def test_class_resume_ctx():
    """ctx="resume" on methods keeps separate histories per instance and per class.

    ``a.append``, ``b.append`` and ``A.append_cls`` each accumulate only their
    own messages across calls. The bare ``msg`` expression statement records
    the argument's value as a prompt.
    """

    class A:
        @ppl(ctx="resume")
        def append(self, msg: str):
            msg
            return convo()

        @classmethod
        @ppl(ctx="resume")
        def append_cls(cls, msg: str):
            msg
            return convo()

    a = A()
    b = A()
    target_a = []
    target_b = []
    target_cls = []
    # The loop index is unused, hence `_` (ruff B007).
    for _ in range(3):
        res = a.append("Hello")
        target_a.append("Hello")
        assert str(res) == "\n".join(target_a)
        res = b.append("World")
        target_b.append("World")
        assert str(res) == "\n".join(target_b)
        res = A.append_cls("Class")
        target_cls.append("Class")
        assert str(res) == "\n".join(target_cls)

307 

308 

def test_class_func():
    """@ppl(ctx="same") methods contribute their prompts to the calling @ppl method."""

    class ComplexPrompt:
        def __init__(self, condition: bool):
            # Annotated as bool: the test constructs this with True/False
            # (was mis-annotated as str).
            self._condition = condition

        @ppl(ctx="same")
        def sub1(self):
            if self._condition:
                "sub1, condition is true"
            else:
                "sub1, condition is false"

        @ppl(ctx="same")
        def sub2(self):
            if self._condition:
                "sub2, condition is true"
            else:
                "sub2, condition is false"

        @ppl
        def func(self):
            self.sub1()
            self.sub2()
            return records()

    prompt1 = ComplexPrompt(False).func()
    prompt2 = ComplexPrompt(True).func()
    # Only the branch matching the condition is recorded.
    assert str(prompt1) == "sub1, condition is false\nsub2, condition is false"
    assert str(prompt2) == "sub1, condition is true\nsub2, condition is true"

338 

339 

def test_generation_message():
    """Each lazy gen() call captures the messages recorded up to that call."""
    appl.init()

    @ppl
    def func():
        "Hello World"
        gen1 = gen(lazy_eval=True)
        "Hi"
        gen2 = gen(lazy_eval=True)
        return gen1, gen2

    # Inspects the internal request arguments of each lazily-created generation.
    captured = [str(g._args.messages) for g in func()]
    assert captured == ["Hello World", "Hello World\nHi"]

354 

355 

def test_generation_message2():
    """Prompt f-strings may interpolate the result of an ordinary helper call."""

    def fakegen():
        return "24"

    @ppl
    def func():
        f"Q: 1 + 2 = ?"
        f"A: 3"
        f"Q: 15 + 9 = ?"
        f"A: {fakegen()}"
        return convo()

    expected = "\n".join(["Q: 1 + 2 = ?", "A: 3", "Q: 15 + 9 = ?", "A: 24"])
    assert str(func()) == expected