Coverage for tests/test_prompt.py: 67%
224 statements
coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1import appl
2import pytest
3from appl import Generation, convo, define, gen, need_ctx, ppl, records
4from appl.compositor import *
def test_return():
    """A @ppl function hands back its explicit return value, not the prompt."""

    @ppl
    def answer_only():
        "Hello World"
        return "answer"

    result = answer_only()
    assert result == "answer"
def test_prompt():
    """records() returns the prompt built from the function's string statements."""

    @ppl
    def collect(_ctx):
        "Hello World"
        return records()

    prompt = collect()
    assert str(prompt) == "Hello World"
def test_fstring():
    """f-string prompt statements honour conversions (!r) and format specs (:.2f)."""

    @ppl
    def with_conversion():
        f"a is {1!r}"
        return records()

    @ppl
    def with_format_spec():
        f"a is {3.1415:.2f}"
        return records()

    # Each prompt must render exactly like the same f-string evaluated by Python.
    cases = (
        (with_conversion, f"a is {1!r}"),
        (with_format_spec, f"a is {3.1415:.2f}"),
    )
    for prompt_fn, expected in cases:
        assert str(prompt_fn()) == expected
def test_prompts_change():
    """records() is a live view of the prompt; .copy() freezes a snapshot."""

    @ppl
    def func():
        "Hello"
        live_view = records()  # the reference: later prompts show up here too
        snapshot = records().copy()  # frozen copy of the prompt so far
        "World"
        final = records()
        return live_view, snapshot, final

    live_view, snapshot, final = func()
    assert str(live_view) == "Hello\nWorld"
    assert str(snapshot) == "Hello"
    assert str(final) == "Hello\nWorld"
def test_return_prompt():
    """default_return="prompt" supplies the prompt only when nothing is returned.

    A function without a return statement yields its prompt, while an explicit
    return value is passed through untouched.
    """

    @ppl(default_return="prompt")
    def implicit_return():
        "Hello World"

    @ppl(default_return="prompt")
    def explicit_return():
        "Hello World"
        return "answer"

    assert str(implicit_return()) == "Hello World"
    # The explicit return value wins over the default.
    assert str(explicit_return()) == "answer"
def test_record():
    """Prompts from a nested @ppl call are recorded in the caller's current format.

    f2's two lines are emitted inside func's NumberedList block, so they continue
    the numbering as items 3 and 4.
    """

    @ppl
    def f2():
        "Hello"
        "World"
        return records()

    @ppl
    def func():
        with NumberedList():
            "first line"
            "second line"
            f2()  # add the prompts from f2, following the current format.
        return records()

    # Plain literal: the expected text has no placeholders, so no f-prefix (ruff F541).
    assert str(func()) == "1. first line\n2. second line\n3. Hello\n4. World"
def test_inner_func():
    """A plain helper nested in a @ppl function adds to the same prompt."""

    @ppl
    def func():
        "Hello"

        def func2():  # the inner function uses the same context as the outer one.
            "World"

        func2()
        return records()

    combined = func()
    assert str(combined) == "Hello\nWorld"
# Exercises how @ppl turns a multi-line triple-quoted string statement into
# prompt text. Each helper starts with a one-line docstring, which does not
# appear in any expected output, followed by one multi-line string laid out
# differently: func1/func3 open with a newline right after the quotes,
# func2/func4 put "begin" on the opening-quote line, and func4's continuation
# lines start at column 0.
# NOTE(review): runs of spaces inside the literals and expected strings appear
# collapsed in this rendering; verify exact whitespace against the original
# tests/test_prompt.py before editing.
105def test_tripple_quote():
106 @ppl
107 def func1():
108 """This is a docstring"""
109 """
110 begin
111 1. first
112 2. second
113 """
114 return records()
116 @ppl
117 def func2():
118 """This is a docstring"""
119 # Note that dedent will remove the leading spaces.
120 """begin
121 1. first
122 2. second
123 """
124 return records()
126 @ppl
127 def func3():
128 """This is a docstring"""
129 # Not recommended
130 """
131 begin
132 1. first
133 2. second
134 """ # note the leading spaces here
135 return records()
137 @ppl
138 def func4():
139 """This is a docstring"""
140 # Not recommended, but still works.
141 """begin
1421. first
1432. second
144"""
145 return records()
147 assert str(func1()) == "begin\n 1. first\n 2. second"
148 assert str(func2()) == "begin\n1. first\n2. second"
149 assert str(func3()) == "begin\n 1. first\n 2. second\n "
150 assert str(func4()) == "begin\n1. first\n2. second"
# Exercises multi-line triple-quoted f-strings as prompt statements:
# func1 interpolates a local variable; func2 interpolates a value that is itself
# a multi-line f-string, and its expected output keeps the inner text's extra
# indentation and blank line; func3 splits a replacement field across lines with
# a backslash continuation and expects it to evaluate to "1+1=2".
# NOTE(review): a backslash inside a replacement field presumably needs
# Python 3.12+ (PEP 701) — confirm the supported Python range.
# NOTE(review): source line 168 (a blank line inside func2's string) is missing
# from this rendering, and runs of spaces appear collapsed; verify literals
# against the original tests/test_prompt.py before editing.
153def test_tripple_quote_fstring():
154 @ppl
155 def func1():
156 x = "end"
157 f"""
158 begin
159 {x}
160 """
161 return records()
163 @ppl
164 def func2():
165 x = f"""
166 begin
167 Hello
169 World
170 """
171 f"""
172 {x}
173 end
174 """
175 return records()
177 @ppl
178 def func3():
179 f"""
180 1+1={1 + \
181 1
182 }
183 """
184 # The recovered code from libcst will become:
185 # f"""
186 # 1+1={1 + \
187 # 1
188 # }
189 # """
190 return records()
192 assert str(func1()) == "begin\nend"
193 assert str(func2()) == "begin\n Hello\n\n World\nend"
194 assert str(func3()) == "1+1=2"
def test_include_docstring():
    """include_docstring=True makes the docstring the first line of the prompt."""

    @ppl(include_docstring=True)
    def documented():
        """This is a docstring"""
        "Hello"
        return records()

    expected = "This is a docstring\nHello"
    assert str(documented()) == expected
# With include_docstring=True, a multi-line docstring becomes prompt text.
# func's docstring starts on the opening-quote line, and in its expected output
# the continuation line sits at the left margin, followed by the "Hello" prompt.
# func2's docstring opens with a newline after the quotes, and its expected output
# keeps the second line indented relative to the first.
# NOTE(review): runs of spaces appear collapsed in this rendering; check the exact
# indentation of the literals and expected strings against the original
# tests/test_prompt.py before editing.
207def test_include_multiline_docstring():
208 @ppl(include_docstring=True)
209 def func():
210 """This is a
211 multiline docstring"""
213 "Hello"
214 return records()
216 assert str(func()) == "This is a\nmultiline docstring\nHello"
218 @ppl(include_docstring=True)
219 def func2():
220 """
221 This is a
222 multiline docstring
223 """
224 return records()
226 assert str(func2()) == "This is a\n multiline docstring"
def test_default_no_docstring():
    """By default only the leading docstring is left out of the prompt.

    A second string with the same text is an ordinary prompt statement and is kept.
    """

    @ppl()
    def docstring_then_prompt():
        """This is a docstring"""
        "Hello"
        return records()

    @ppl()
    def repeated_string():
        """Same string"""  # the first is docstring
        # the second string is not a docstring anymore, should be included
        """Same string"""
        return records()

    rendered = (str(docstring_then_prompt()), str(repeated_string()))
    assert rendered == ("Hello", "Same string")
def test_copy_ctx():
    """ctx="copy" extends a copy of the caller's prompt and leaves the caller's untouched."""

    @ppl(ctx="copy")
    def addon():
        "World"
        return str(convo())

    @ppl
    def caller():
        "Hello"
        first_copy = addon()
        second_copy = addon()
        return first_copy, second_copy, records()

    first, second, origin = caller()
    # Each call starts from a fresh copy, so "World" never accumulates.
    for copied in (first, second):
        assert copied == "Hello\nWorld"
    assert str(origin) == "Hello"
def test_resume_ctx():
    """ctx="resume" carries the function's conversation over to its next call.

    Each call adds one more "Hello", so after the k-th call the conversation
    renders as k "Hello" lines joined by newlines.
    """

    @ppl(ctx="resume")
    def resume_ctx():
        "Hello"
        return convo()

    target = []
    for _ in range(3):  # the iteration index itself is not needed
        res = resume_ctx()
        target += ["Hello"]
        assert str(res) == "\n".join(target)
def test_class_resume_ctx():
    """ctx="resume" on methods keeps a separate conversation per receiver.

    Instances ``a`` and ``b`` each accumulate only their own messages, and the
    classmethod accumulates its own history independently of both instances.
    """

    class A:
        @ppl(ctx="resume")
        def append(self, msg: str):
            msg
            return convo()

        @classmethod
        @ppl(ctx="resume")
        def append_cls(cls, msg: str):
            msg
            return convo()

    a = A()
    b = A()
    target_a = []
    target_b = []
    target_cls = []
    for _ in range(3):  # the iteration index itself is not needed
        res = a.append("Hello")
        target_a += ["Hello"]
        assert str(res) == "\n".join(target_a)
        res = b.append("World")
        target_b += ["World"]
        assert str(res) == "\n".join(target_b)
        res = A.append_cls("Class")
        target_cls += ["Class"]
        assert str(res) == "\n".join(target_cls)
def test_class_func():
    """Methods decorated with ctx="same" write into the calling @ppl method's prompt.

    func calls sub1 and sub2, and their branch-dependent strings appear in
    func's records in call order.
    """

    class ComplexPrompt:
        def __init__(self, condition: bool):
            # Annotated as bool: every caller below passes True/False, and the
            # sub-prompts only test its truthiness.
            self._condition = condition

        @ppl(ctx="same")
        def sub1(self):
            if self._condition:
                "sub1, condition is true"
            else:
                "sub1, condition is false"

        @ppl(ctx="same")
        def sub2(self):
            if self._condition:
                "sub2, condition is true"
            else:
                "sub2, condition is false"

        @ppl
        def func(self):
            self.sub1()
            self.sub2()
            return records()

    prompt1 = ComplexPrompt(False).func()
    prompt2 = ComplexPrompt(True).func()
    assert str(prompt1) == "sub1, condition is false\nsub2, condition is false"
    assert str(prompt2) == "sub1, condition is true\nsub2, condition is true"
def test_generation_message():
    """Each lazy gen() call captures the prompt accumulated up to that point."""
    appl.init()

    @ppl
    def two_generations():
        "Hello World"
        first_gen = gen(lazy_eval=True)
        "Hi"
        second_gen = gen(lazy_eval=True)
        return first_gen, second_gen

    first_gen, second_gen = two_generations()
    # _args.messages is the prompt each pending generation would send.
    expectations = (
        (first_gen, "Hello World"),
        (second_gen, "Hello World\nHi"),
    )
    for generation, expected in expectations:
        assert str(generation._args.messages) == expected
def test_generation_message2():
    """f-string prompts may embed calls; a local stub stands in for a model answer."""

    def fake_answer():
        return "24"

    @ppl
    def func():
        f"Q: 1 + 2 = ?"
        f"A: 3"
        f"Q: 15 + 9 = ?"
        f"A: {fake_answer()}"
        return convo()

    expected_lines = ["Q: 1 + 2 = ?", "A: 3", "Q: 15 + 9 = ?", "A: 24"]
    assert str(func()) == "\n".join(expected_lines)