Coverage for src/appl/core/compile.py: 95%

238 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1from __future__ import annotations 

2 

3import ast 

4import inspect 

5import linecache 

6import os 

7import re 

8import sys 

9import textwrap 

10import traceback 

11from ast import ( 

12 AST, 

13 Assign, 

14 Attribute, 

15 Call, 

16 Constant, 

17 Expr, 

18 FormattedValue, 

19 FunctionDef, 

20 JoinedStr, 

21 Load, 

22 Name, 

23 NamedExpr, 

24 NodeTransformer, 

25 Store, 

26 With, 

27 stmt, 

28) 

29from types import CodeType 

30from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 

31 

32import libcst as cst 

33from loguru import logger 

34 

35from .context import PromptContext 

36 

# Pre-built AST fragments that the transformers splice into rewritten code.
# Each fragment is created once at import time; the same node object is
# appended wherever it is needed.

# Keyword argument `_globals=globals()`.
GLOBALS_KEYWORD = ast.keyword(
    arg="_globals",
    value=ast.Call(
        func=ast.Name(id="globals", ctx=ast.Load()),
        args=[],
        keywords=[],
    ),
)
# Keyword argument `_locals=locals()`.
LOCALS_KEYWORD = ast.keyword(
    arg="_locals",
    value=ast.Call(
        func=ast.Name(id="locals", ctx=ast.Load()),
        args=[],
        keywords=[],
    ),
)
# Keyword argument `_ctx=_ctx`.
CTX_KEYWORD = ast.keyword(
    arg="_ctx",
    value=ast.Name(id="_ctx", ctx=ast.Load()),
)
# Parameter declaration `_ctx: PromptContext`.
CTX_ARG = ast.arg(
    arg="_ctx",
    annotation=ast.Name(id="PromptContext", ctx=ast.Load()),
)

47 

48 

49def _has_arg(args: Union[List[ast.arg], List[ast.keyword]], name: str) -> bool: 

50 return any( 

51 isinstance(arg, (ast.arg, ast.keyword)) and arg.arg == name for arg in args 

52 ) 

53 

54 

class ApplNodeTransformer(NodeTransformer):
    """Shared base class for the AST rewriting passes of the APPL compiler."""

    def __init__(self, compile_info: Dict, *args: Any, **kwargs: Any) -> None:
        """Keep ``compile_info`` and forward everything else to ``NodeTransformer``."""
        super().__init__(*args, **kwargs)
        self._compile_info = compile_info

    def _raise_syntax_error(self, lineno: int, col_offset: int, msg: str) -> None:
        """Raise a ``SyntaxError`` that points into the original source file.

        Args:
            lineno: 1-based line number relative to the compiled snippet.
            col_offset: Column offset reported with the error.
            msg: The error message.

        Raises:
            SyntaxError: Always. The reported line is ``lineno`` shifted by the
                snippet's starting line (``compile_info["lineno"]``).
        """
        source_path = self._compile_info["sourcefile"]
        # Translate the snippet-relative line number into a file line number.
        absolute_lineno = self._compile_info["lineno"] + lineno - 1
        offending_line = linecache.getline(source_path, absolute_lineno)
        details = (source_path, absolute_lineno, col_offset, offending_line)
        raise SyntaxError(msg, details)

68 

69 

class RemoveApplDecorator(ApplNodeTransformer):
    """AST pass that strips decorators and rejects ``ppl`` on nested functions."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Set up the pass; the first function definition visited is the outermost."""
        super().__init__(*args, **kwargs)
        self._outmost = True

    def _is_ppl_decorator(self, decorator: AST) -> bool:
        """Return True for a decorator spelled ``ppl`` or ``ppl(...)``."""
        # For the call form, look at the callee; otherwise at the node itself.
        target = decorator.func if isinstance(decorator, Call) else decorator
        return isinstance(target, Name) and target.id == "ppl"

    def visit_FunctionDef(self, node):
        """Drop every decorator from the function definition.

        Raises:
            SyntaxError: If a ``ppl`` decorator appears on a function that is
                not the outermost one visited by this pass.
        """
        nested = not self._outmost
        if node.decorator_list:
            if nested:
                for decorator in node.decorator_list:
                    if self._is_ppl_decorator(decorator):
                        self._raise_syntax_error(
                            decorator.lineno,
                            decorator.col_offset,
                            "Nested ppl decorator is not allowed yet for APPL.",
                        )
            # Every decorator is removed, not only ppl.
            node.decorator_list = []
        # Any function definition reached from here on is nested.
        self._outmost = False
        self.generic_visit(node)
        return node

103 

104 

class SplitString(ApplNodeTransformer):
    """AST pass that breaks an f-string statement into one statement per part."""

    def _add_formatted_value(self, node: FormattedValue) -> Iterable[stmt]:
        """Turn one ``{...}`` field into an ``appl.format(...)`` expression statement.

        The generated call is ``appl.format(value[, format_spec][, conversion=...])``.

        Raises:
            SyntaxError: If the field looks like an unbracketed named expression
                (``{x:=...}``), which Python parses as a format spec starting with ``=``.
        """
        call_args = [node.value]
        if node.format_spec is not None:
            call_args.append(node.format_spec)
        # -1 means that no conversion (!r, !s, !a) was given.
        call_keywords = (
            [ast.keyword(arg="conversion", value=Constant(node.conversion))]
            if node.conversion != -1
            else []
        )
        format_call = Call(
            func=Attribute(value=Name(id="appl", ctx=Load()), attr="format", ctx=Load()),
            args=call_args,
            keywords=call_keywords,
        )
        statements: Iterable[stmt] = [Expr(format_call)]

        # A properly bracketed named expression needs no further checks.
        if isinstance(node.value, NamedExpr):
            return statements

        spec = node.format_spec
        if spec:
            spec_str = ast.unparse(spec)
            # spec_str is the unparsed f-string; index 2 is its first content char.
            if spec_str and spec_str[2] == "=":  # f"= ..."
                self._raise_syntax_error(
                    node.lineno,
                    node.col_offset,
                    "Not supported format of named expression inside f-string. "
                    "To use named expression, please add brackets around "
                    f"`{ast.unparse(node)[3:-2]}`.",
                )
        return statements

    def visit_Expr(self, node: Expr) -> stmt:
        """Split an f-string statement so that each part becomes its own statement.

        Statements that are not f-strings, and f-strings without parts, are
        returned unchanged. A single part is returned as a bare statement;
        several parts are wrapped in ``with appl.Str(): ...``.
        """
        fstring = node.value
        if not isinstance(fstring, JoinedStr):
            return node

        parts: List[stmt] = []
        for piece in fstring.values:
            if isinstance(piece, Constant):
                parts.append(Expr(piece))
            elif isinstance(piece, FormattedValue):
                parts.extend(self._add_formatted_value(piece))
            else:
                raise ValueError(f"Unknown value type in a JoinedStr: {type(piece)}")

        if not parts:  # empty string
            return node
        if len(parts) == 1:  # single string
            return parts[0]

        str_context = Call(
            func=Attribute(value=Name(id="appl", ctx=Load()), attr="Str", ctx=Load()),
            args=[],
            keywords=[],
        )
        return With(items=[ast.withitem(context_expr=str_context)], body=parts)

181 

182 

class CallWithContext(ApplNodeTransformer):
    """AST pass that routes every call through ``appl.with_ctx``."""

    def visit_Call(self, node: Call) -> Call:
        """Rewrite ``f(*a, **k)`` into ``appl.with_ctx(*a, **k, _func=f, ...)``.

        Besides ``_func``, the rewritten call receives ``_ctx`` (unless the
        original call already passes it), ``_globals`` and ``_locals``.
        """
        # Rewrite nested calls first.
        self.generic_visit(node)

        # The keyword list of the original node is extended in place and reused.
        keywords = node.keywords
        keywords.append(ast.keyword(arg="_func", value=node.func))
        if not _has_arg(keywords, "_ctx"):
            keywords.append(CTX_KEYWORD)
        keywords.extend([GLOBALS_KEYWORD, LOCALS_KEYWORD])

        wrapper = Attribute(value=Name(id="appl", ctx=Load()), attr="with_ctx", ctx=Load())
        return Call(wrapper, node.args, keywords)

208 

209 

class AddCtxToArgs(ApplNodeTransformer):
    """AST pass that extends the outermost function signature with ``_ctx``."""

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Append ``_ctx`` and the free variables as keyword-only parameters.

        ``_ctx`` defaults to ``PromptContext()`` and is only added when neither
        the regular nor the keyword-only parameters already declare it. Every
        name in ``compile_info["freevars"]`` becomes a keyword-only parameter
        whose default is the variable of the same name.
        """
        # NOTE: generic_visit is deliberately not called, so nested function
        # definitions are left untouched by this pass.
        signature = node.args
        declares_ctx = _has_arg(signature.args, "_ctx") or _has_arg(
            signature.kwonlyargs, "_ctx"
        )
        if not declares_ctx:
            signature.kwonlyargs.append(CTX_ARG)
            # PromptContext() as default
            default_ctx = Call(
                func=Name(id="PromptContext", ctx=Load()), args=[], keywords=[]
            )
            signature.kw_defaults.append(default_ctx)

        for var in self._compile_info["freevars"]:
            signature.kwonlyargs.append(ast.arg(arg=var))
            signature.kw_defaults.append(ast.Name(id=var, ctx=Load()))
            logger.debug(f"add freevar {var} to function {node.name} args.")
        return node

230 

231 

class AddExecuteWrapper(ApplNodeTransformer):
    """AST pass that wraps expression statements in ``appl.execute``."""

    def visit_Expr(self, node: Expr) -> Expr:
        """Replace the statement ``expr`` with ``appl.execute(expr, _ctx=_ctx)``."""
        execute = Attribute(value=Name(id="appl", ctx=Load()), attr="execute", ctx=Load())
        wrapped = Call(func=execute, args=[node.value], keywords=[CTX_KEYWORD])
        return Expr(wrapped)

248 

249 

class DedentTripleQuotedString(cst.CSTTransformer):
    """A CST transformer that dedents triple-quoted strings in the source code."""

    def visit_Module(self, node: cst.Module) -> bool:
        """Remember the module node; it is needed later to render node source."""
        self.module = node
        return True

    def leave_SimpleString(
        self, original_node: cst.SimpleString, updated_node: cst.SimpleString
    ) -> cst.SimpleString:
        """Dedent a triple-quoted plain string with ``inspect.cleandoc``.

        Any string prefix (``r``, ``R``, ``u``, ``b``, ``rb`` ...) is preserved
        as written. Strings that are not triple-quoted are returned unchanged.
        """
        value = original_node.value
        delim = value[-3:]
        # Only strings closed by triple quotes (""" or ''') are processed.
        if delim not in ('"""', "'''"):
            return updated_node

        # Separate the (possibly empty) letter prefix from the quoted part, so
        # that every legal prefix is handled, not only a lowercase "r".
        prefix_match = re.match(r"[A-Za-z]*", value)
        prefix = prefix_match.group() if prefix_match else ""
        quoted = value[len(prefix) :]
        if len(quoted) < 6 or not quoted.startswith(delim):
            # Not a well-formed triple-quoted literal; leave it alone.
            return updated_node

        # Clean the content in the same way docstrings are cleaned.
        modified_inner_string = inspect.cleandoc(quoted[3:-3])
        new_value = f"{prefix}{delim}{modified_inner_string}{delim}"

        logger.debug(
            f"multiline string dedented\n"
            f"before:\n{original_node.value}\n"
            f"after:\n{new_value}"
        )
        return updated_node.with_changes(value=new_value)

    # format string
    def leave_FormattedString(
        self, original_node: cst.FormattedString, updated_node: cst.FormattedString
    ) -> cst.FormattedString:
        """Dedent a triple-quoted f-string.

        The original source text of the node is dedented and parsed again.

        Raises:
            RuntimeError: If the dedented text no longer parses as an f-string.
        """
        if len(original_node.end) != 3:
            return updated_node

        def get_num_leading_whitespace(line: str) -> int:
            """Count the leading spaces and tabs of ``line``."""
            if g := re.match(r"^[ \t]*", line):
                return len(g.group())
            return 0

        def dedent_str(text: str) -> str:
            """Apply ``cleandoc`` and strip any margin that it left behind."""
            cleaned = inspect.cleandoc(text)
            lines = cleaned.splitlines()
            margin: Optional[int] = None
            for line in lines:
                n = get_num_leading_whitespace(line)
                # Workaround for `libcst.code_for_node`: the leading whitespace
                # of a line starting with '}' is not restored correctly, so such
                # lines (and blank lines) are ignored when computing the margin.
                # The same applies to the contents between `{` and `}` (TODO).
                if len(line) > n and line[n] != "}":
                    margin = n if margin is None else min(margin, n)
            if margin is not None and margin > 0:
                new_lines = [
                    line[min(get_num_leading_whitespace(line), margin) :]
                    for line in lines
                ]
                # cleandoc yields "\n"-separated text; join with "\n" as well so
                # that no platform-specific line endings leak into the code.
                cleaned = inspect.cleandoc("\n".join(new_lines))  # clean again
                logger.warning(
                    f"unusual multiline formatted string detected, "
                    f"this is a bug related to libcst.code_for_node\n"
                    f"The string after manual fix is:\n{cleaned}"
                )
            logger.debug(
                f"multiline formmated string dedented\n"
                f"before:\n{text}\n"
                f"after:\n{cleaned}"
            )
            return cleaned

        original_code = self.module.code_for_node(original_node)
        # Strip the opening (prefix + quotes) and the closing quotes.
        simple_str = original_code[len(original_node.start) : -len(original_node.end)]
        format_str = f"{original_node.start}{dedent_str(simple_str)}{original_node.end}"
        parsed_cst = cst.parse_expression(format_str)
        if isinstance(parsed_cst, cst.FormattedString):
            return parsed_cst
        raise RuntimeError(
            f"Failed to parse modified formatted string as libcst.FormattedString: {format_str}"
        )

350 

351 

def dedent_triple_quoted_string(code: str) -> str:
    """Return ``code`` with its triple-quoted strings dedented.

    The code is parsed into a CST, rewritten by ``DedentTripleQuotedString``
    and rendered back to source text.
    """
    transformer = DedentTripleQuotedString()
    rewritten_module = cst.parse_module(code).visit(transformer)
    return rewritten_module.code

361 

362 

def _get_docstring_quote_count(code: str) -> Optional[int]:
    """Report how many quote characters delimit the first function's docstring.

    Only the first top-level function definition in ``code`` is inspected.

    Args:
        code: Source code that contains a top-level function definition.

    Returns:
        3 for a triple-quoted docstring, 1 for a single-quoted one, and None
        when the first statement of the function is not a plain string literal.

    Raises:
        AssertionError: If ``code`` contains no top-level function definition.
    """
    module = cst.parse_module(code)
    func = next(
        (node for node in module.body if isinstance(node, cst.FunctionDef)), None
    )
    if func is None:
        # Raised explicitly so that the check also holds under `python -O`.
        raise AssertionError("no function detected")

    statements = func.body.body
    if not statements:
        return None
    first = statements[0]
    if not isinstance(first, cst.SimpleStatementLine) or not first.body:
        return None
    expression = first.body[0]
    if not isinstance(expression, cst.Expr):
        return None
    literal = expression.value
    if not isinstance(literal, cst.SimpleString):
        return None

    # Skip a string prefix such as r/R/u/U so that r"""...""" counts as triple.
    unprefixed = literal.value.lstrip("rRuUbB")
    return 3 if unprefixed.startswith(('"""', "'''")) else 1

379 

380 

class APPLCompiled:
    """A compiled APPL function that can be called with context."""

    def __init__(
        self, code: CodeType, ast: AST, original_func: Callable, compile_info: Dict
    ):
        """Initialize the compiled function.

        Args:
            code: The compiled code object.
            ast: The AST of the compiled code.
            original_func: The original function.
            compile_info: The compile information; its ``"source"`` entry is
                used to detect the quoting of the docstring.
        """
        self._code = code
        self._ast = ast
        self._name = original_func.__name__
        self._original_func = original_func
        self._compile_info = compile_info
        self._docstring_quote_count = _get_docstring_quote_count(compile_info["source"])

    @property
    def freevars(self) -> Tuple[str, ...]:
        """Get the free variables of the compiled function.

        The value recorded in the compile info is preferred; otherwise the
        free variables of the original function's code object are used.
        """
        if "freevars" in self._compile_info:
            return self._compile_info["freevars"]
        return self._original_func.__code__.co_freevars

    def __call__(
        self,
        *args: Any,
        _globals: Optional[Dict] = None,
        _locals: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Any:
        """Call the compiled function.

        Args:
            *args: Positional arguments for the function.
            _globals: Globals to execute the compiled code in. A falsy value
                selects the globals of the original function. The mapping
                gets an ``appl`` entry added when it has none.
            _locals: Mapping that provides values for the closure (free)
                variables. None is treated as an empty mapping.
            **kwargs: Keyword arguments for the function.

        Returns:
            Whatever the compiled function returns.

        Raises:
            ValueError: If the value of a free variable cannot be determined.
        """
        _globals = _globals or self._original_func.__globals__
        if "appl" not in _globals:
            _globals["appl"] = sys.modules["appl"]
        closure_values = {} if _locals is None else _locals

        local_vars: Dict[str, Any] = {"PromptContext": PromptContext}
        # get closure variables from locals
        for name in self.freevars:
            if name in closure_values:
                local_vars[name] = closure_values[name]
            elif name == "__class__" and args:
                # Workaround attempt for super(): take the class from the
                # first positional argument (presumably `self`).
                local_vars["__class__"] = args[0].__class__
            else:
                raise ValueError(
                    f"Freevar '{name}' not found. If you are using closure variables, "
                    "please provide their values in the _locals argument. "
                    "For example, assume the function is `func`, use `func(..., _locals=locals())`. "
                    "Alternatively, you can first use the `appl.as_func` to convert the "
                    "function within the current scope (automatically feeding the locals)."
                )

        # Executing the module code defines the function inside local_vars.
        exec(self._code, _globals, local_vars)  # TODO: use closure argument
        func = local_vars[self._name]
        return func(*args, **kwargs)

    def __repr__(self):
        """Return a short representation containing the function name."""
        return f"APPLCompiled({self._name})"

448 

449 

def appl_dedent(source: str) -> str:
    """Remove the common indentation from ``source``.

    ``textwrap.dedent`` is tried first. If the result does not parse, the
    indentation of the first non-blank line is removed from the start of
    every line instead, and warnings are logged.
    """
    dedented = textwrap.dedent(source)
    try:
        ast.parse(dedented)
    except Exception:  # the dedent failed due to multiline string
        logger.warning(
            "The source code contains multiline string that cannot be dedented. "
            "It is recommended to dedent the multiline string aligning with the function, "
            "where APPL will automatically dedent the multiline string "
            "in the same way as cleaning docstring."
        )
        # The indentation of the first non-blank line is taken as the margin.
        indent_pattern = re.compile("(^[ \t]*)(?:[^ \t\n])", re.MULTILINE)
        first_indent = indent_pattern.findall(dedented)[0]
        dedented = re.sub(r"(?m)^" + first_indent, "", dedented)
        logger.warning(f"The source code after workaround:\n{dedented}")
    return dedented

469 

470 

def appl_compile(func: Callable) -> APPLCompiled:
    """Compile an APPL function.

    The source of ``func`` is dedented, rewritten by the AST transformers and
    compiled under a synthetic filename. The dedented source is registered in
    ``linecache`` under that filename.

    Args:
        func: The function to compile; its source must be retrievable through
            ``inspect``.

    Returns:
        An ``APPLCompiled`` wrapping the compiled code, the transformed AST,
        the original function and the compile information.
    """
    sourcefile = inspect.getsourcefile(func)
    # One lookup yields both the source text and its starting line number.
    source_lines, lineno = inspect.getsourcelines(func)
    source = appl_dedent("".join(source_lines))

    # Register the source under the synthetic filename; the None mtime marks
    # the entry as not backed by a file on disk.
    key = f"<appl-compiled:{sourcefile}:{lineno}>"
    linecache.cache[key] = (
        len(source),
        None,
        [line + os.linesep for line in source.splitlines()],
        key,
    )

    source = dedent_triple_quoted_string(source)
    parsed_ast = ast.parse(source)
    logger.debug(
        f"\n{'-'*20} code BEFORE appl compile {'-'*20}\n{ast.unparse(parsed_ast)}"
    )

    # The passes run in this order; each one gets a fresh transformer instance.
    transformers = [
        RemoveApplDecorator,
        SplitString,
        CallWithContext,
        AddCtxToArgs,
        AddExecuteWrapper,
    ]
    compile_info = {
        "source": source,
        "sourcefile": sourcefile,
        "lineno": lineno,
        "func_name": func.__name__,
        "freevars": func.__code__.co_freevars,
        "docstring": func.__doc__,
    }
    for transformer in transformers:
        parsed_ast = transformer(compile_info).visit(parsed_ast)

    # Newly created nodes carry no positions; fill them in before compiling.
    parsed_ast = ast.fix_missing_locations(parsed_ast)
    compiled_ast = compile(parsed_ast, filename=key, mode="exec")
    logger.debug(
        f"\n{'-'*20} code AFTER appl compile {'-'*20}\n{ast.unparse(parsed_ast)}"
    )

    return APPLCompiled(compiled_ast, parsed_ast, func, compile_info)