Coverage for src/appl/core/compile.py: 95%
238 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 15:39 -0800
1from __future__ import annotations
3import ast
4import inspect
5import linecache
6import os
7import re
8import sys
9import textwrap
10import traceback
11from ast import (
12 AST,
13 Assign,
14 Attribute,
15 Call,
16 Constant,
17 Expr,
18 FormattedValue,
19 FunctionDef,
20 JoinedStr,
21 Load,
22 Name,
23 NamedExpr,
24 NodeTransformer,
25 Store,
26 With,
27 stmt,
28)
29from types import CodeType
30from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
32import libcst as cst
33from loguru import logger
35from .context import PromptContext
def _scope_snapshot_keyword(arg_name: str, builtin_name: str) -> ast.keyword:
    """Build the AST for the keyword argument ``arg_name=builtin_name()``."""
    snapshot_call = Call(
        func=Name(id=builtin_name, ctx=Load()),
        args=[],
        keywords=[],
    )
    return ast.keyword(arg=arg_name, value=snapshot_call)


# AST for `_globals=globals()`; appended to rewritten calls by CallWithContext.
GLOBALS_KEYWORD = _scope_snapshot_keyword("_globals", "globals")
# AST for `_locals=locals()`; appended to rewritten calls by CallWithContext.
LOCALS_KEYWORD = _scope_snapshot_keyword("_locals", "locals")
# AST for `_ctx=_ctx`; forwards the current context variable as a keyword.
CTX_KEYWORD = ast.keyword(arg="_ctx", value=Name(id="_ctx", ctx=Load()))
# AST for the parameter `_ctx: PromptContext`; added to signatures by AddCtxToArgs.
CTX_ARG = ast.arg(arg="_ctx", annotation=Name(id="PromptContext", ctx=Load()))
def _has_arg(args: Union[List[ast.arg], List[ast.keyword]], name: str) -> bool:
    """Return True when ``args`` holds an ``ast.arg``/``ast.keyword`` named ``name``.

    Entries of any other type are ignored.
    """
    for candidate in args:
        if not isinstance(candidate, (ast.arg, ast.keyword)):
            continue
        if candidate.arg == name:
            return True
    return False
class ApplNodeTransformer(NodeTransformer):
    """Common base class for the AST transformers used by the APPL compiler."""

    def __init__(self, compile_info: Dict, *args: Any, **kwargs: Any) -> None:
        """Keep ``compile_info``; all other arguments go to ``NodeTransformer``."""
        super().__init__(*args, **kwargs)
        self._compile_info = compile_info

    def _raise_syntax_error(self, lineno: int, col_offset: int, msg: str) -> None:
        """Raise a ``SyntaxError`` that points into the original source file.

        Args:
            lineno: 1-based line number inside the compiled snippet.
            col_offset: Column offset reported in the error.
            msg: The error message.

        Raises:
            SyntaxError: Always. The reported line is ``lineno`` shifted by
                ``compile_info["lineno"]`` within ``compile_info["sourcefile"]``.
        """
        info = self._compile_info
        source_path = info["sourcefile"]
        # Translate the snippet-relative line into a line of the source file.
        absolute_lineno = lineno + info["lineno"] - 1
        offending_line = linecache.getline(source_path, absolute_lineno)
        details = (source_path, absolute_lineno, col_offset, offending_line)
        raise SyntaxError(msg, details)
class RemoveApplDecorator(ApplNodeTransformer):
    """AST transformer that clears function decorators and rejects nested ``ppl``."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to the base class and arm the outermost flag."""
        super().__init__(*args, **kwargs)
        # True until the first function definition has been visited.
        self._outmost = True

    def _is_ppl_decorator(self, decorator: AST) -> bool:
        """Report whether ``decorator`` is the name ``ppl`` or a call of that name."""
        target = decorator.func if isinstance(decorator, Call) else decorator
        return isinstance(target, Name) and target.id == "ppl"

    def visit_FunctionDef(self, node):
        """Remove the decorators of ``node`` and continue into its body.

        Raises:
            SyntaxError: If a ``ppl`` decorator is found on a function that is
                not the first (outermost) function visited.
        """
        decorators = node.decorator_list
        if decorators:
            is_nested = not self._outmost
            for candidate in decorators:
                if self._is_ppl_decorator(candidate) and is_nested:
                    self._raise_syntax_error(
                        candidate.lineno,
                        candidate.col_offset,
                        "Nested ppl decorator is not allowed yet for APPL.",
                    )
            # all decorators should be removed
            node.decorator_list = []
        # Every function reached after this one counts as nested.
        self._outmost = False
        self.generic_visit(node)
        return node
class SplitString(ApplNodeTransformer):
    """AST transformer that breaks an f-string statement into one statement per part."""

    def _add_formatted_value(self, node: FormattedValue) -> Iterable[stmt]:
        """Convert one ``{...}`` field into an ``appl.format(...)`` statement.

        The call receives the field value, then the format spec when present,
        and ``conversion=<code>`` when a conversion is set.

        Raises:
            SyntaxError: If the format spec begins with ``=`` and the value is
                not already a named expression.
        """
        call_args = [node.value]
        if node.format_spec is not None:
            call_args.append(node.format_spec)
        if node.conversion == -1:
            call_keywords = []
        else:
            # add conversion
            call_keywords = [
                ast.keyword(arg="conversion", value=Constant(node.conversion))
            ]
        format_func = Attribute(
            value=Name(id="appl", ctx=Load()), attr="format", ctx=Load()
        )
        # converted to `appl.format(value, format_spec)`
        format_call = Call(func=format_func, args=call_args, keywords=call_keywords)
        statements: Iterable[stmt] = [Expr(format_call)]
        if isinstance(node.value, NamedExpr):
            return statements

        if node.format_spec:
            spec_source = ast.unparse(node.format_spec)
            # The unparsed spec looks like f"= ...", so index 2 is its first character.
            if spec_source and spec_source[2] == "=":
                self._raise_syntax_error(
                    node.lineno,
                    node.col_offset,
                    "Not supported format of named expression inside f-string. "
                    "To use named expression, please add brackets around "
                    f"`{ast.unparse(node)[3:-2]}`.",
                )
        return statements

    def visit_Expr(self, node: Expr) -> stmt:
        """Split an f-string statement so each part can get its own appl.execute wrapper.

        Statements that are not f-strings, and f-strings without parts, are
        returned unchanged. A single part is returned as a plain statement;
        several parts are grouped in a ``with appl.Str():`` block.
        """
        fstring = node.value
        if not isinstance(fstring, JoinedStr):
            return node

        parts: List[stmt] = []
        for piece in fstring.values:
            if isinstance(piece, Constant):
                parts.append(Expr(piece))
            elif isinstance(piece, FormattedValue):
                parts.extend(self._add_formatted_value(piece))
            else:
                raise ValueError(f"Unknown value type in a JoinedStr: {type(piece)}")

        if not parts:  # empty string
            return node
        if len(parts) == 1:  # single string
            return parts[0]

        str_scope = Call(
            func=Attribute(value=Name(id="appl", ctx=Load()), attr="Str", ctx=Load()),
            args=[],
            keywords=[],
        )
        return With(items=[ast.withitem(context_expr=str_scope)], body=parts)
class CallWithContext(ApplNodeTransformer):
    """AST transformer that routes every function call through ``appl.with_ctx``."""

    def visit_Call(self, node: Call) -> Call:
        """Rewrite ``f(...)`` into ``appl.with_ctx(..., _func=f, ...)``.

        The original positional and keyword arguments are kept. Added keywords,
        in order: ``_func``, ``_ctx`` (only when the call does not already pass
        one), ``_globals`` and ``_locals``.
        """
        # Rewrite nested calls first.
        self.generic_visit(node)
        # * use appl.with_ctx as wrapper for all functions,
        # * pass _ctx to the function annotated with @need_ctx
        wrapper = Attribute(
            value=Name(id="appl", ctx=Load()), attr="with_ctx", ctx=Load()
        )
        # The keyword list of the original call is reused, not copied.
        keywords = node.keywords
        keywords.append(ast.keyword(arg="_func", value=node.func))
        # add _ctx to kwargs if not present
        if not _has_arg(keywords, "_ctx"):
            keywords.append(CTX_KEYWORD)
        keywords.extend((GLOBALS_KEYWORD, LOCALS_KEYWORD))
        return Call(wrapper, node.args, keywords)
class AddCtxToArgs(ApplNodeTransformer):
    """AST transformer that extends a function signature with ``_ctx`` and free variables."""

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Add keyword-only ``_ctx`` (if absent) and one parameter per free variable.

        ``_ctx`` defaults to ``PromptContext()``; each free variable listed in
        ``compile_info["freevars"]`` defaults to the name of the same variable.
        """
        # ! only the outermost function def is changed: generic_visit is
        # ! deliberately not called, so nested definitions are never reached.
        signature = node.args
        already_has_ctx = _has_arg(signature.args, "_ctx") or _has_arg(
            signature.kwonlyargs, "_ctx"
        )
        if not already_has_ctx:
            signature.kwonlyargs.append(CTX_ARG)
            # PromptContext() as default
            default_ctx = Call(
                func=Name(id="PromptContext", ctx=Load()), args=[], keywords=[]
            )
            signature.kw_defaults.append(default_ctx)
        for freevar in self._compile_info["freevars"]:
            signature.kwonlyargs.append(ast.arg(arg=freevar))
            signature.kw_defaults.append(ast.Name(id=freevar, ctx=Load()))
            logger.debug(f"add freevar {freevar} to function {node.name} args.")
        return node
class AddExecuteWrapper(ApplNodeTransformer):
    """AST transformer that wraps expression statements in ``appl.execute``."""

    def visit_Expr(self, node: Expr) -> Expr:
        """Turn the statement ``<expr>`` into ``appl.execute(<expr>, _ctx=_ctx)``."""
        execute_func = Attribute(
            value=Name(id="appl", ctx=Load()), attr="execute", ctx=Load()
        )
        # Only _ctx is forwarded; globals/locals keywords are not passed here.
        wrapped_call = Call(
            func=execute_func, args=[node.value], keywords=[CTX_KEYWORD]
        )
        return Expr(wrapped_call)
class DedentTripleQuotedString(cst.CSTTransformer):
    """A CST transformer that dedents triple-quoted strings in the source code.

    Both plain and formatted triple-quoted literals have their content cleaned
    with ``inspect.cleandoc``. Literals that are not triple-quoted are left
    untouched. Any string prefix (``r``, ``b``, ``u``, ``rb``, ... in any case)
    is preserved as written.
    """

    def visit_Module(self, node: cst.Module) -> bool:
        """Store the module node for later use."""
        self.module = node
        return True

    @staticmethod
    def _ends_with_bare_quote(text: str, quote_char: str) -> bool:
        """Return True if ``text`` ends with a ``quote_char`` that is not escaped."""
        if not text.endswith(quote_char):
            return False
        head = text[:-1]
        trailing_backslashes = len(head) - len(head.rstrip("\\"))
        # An odd run of backslashes escapes the quote that follows it.
        return trailing_backslashes % 2 == 0

    @classmethod
    def _make_closable(
        cls, content: str, prefix: str, quote_char: str
    ) -> Optional[str]:
        """Make ``content`` safe to place directly before the closing delimiter.

        Removing trailing blank lines can leave a quote character right before
        the closing triple quote (``..."no"`` + ``\"\"\"``), which would end the
        literal one character early. In that case the final quote is escaped.

        Returns:
            The (possibly adjusted) content, or ``None`` when the literal is a
            raw string, where adding a backslash would change its value.
        """
        if not cls._ends_with_bare_quote(content, quote_char):
            return content
        if "r" in prefix.lower():
            return None
        return content[:-1] + "\\" + quote_char

    def leave_SimpleString(
        self, original_node: cst.SimpleString, updated_node: cst.SimpleString
    ) -> cst.SimpleString:
        """Dedent triple-quoted strings in the source code."""
        value = original_node.value
        delim = value[-3:]
        # Check if the string is wrapped by triple quotes (""" or ''')
        if delim not in ('"""', "'''"):
            return updated_node

        # Everything before the opening quotes is the literal prefix.
        prefix = value[: value.index(delim)]
        if len(value) < len(prefix) + 6:  # no separate opening and closing quotes
            return updated_node

        # Remove the prefix and quotes to process the inner string content
        inner_string = value[len(prefix) + 3 : -3]

        # Use the standard of cleandoc to dedent the inner string
        modified_inner_string = self._make_closable(
            inspect.cleandoc(inner_string), prefix, delim[0]
        )
        if modified_inner_string is None:
            logger.warning(
                "multiline raw string ends with a quote character after dedent "
                "and cannot be closed safely, it is left unchanged:\n"
                f"{original_node.value}"
            )
            return updated_node

        # Add back the prefix and the triple quotes to the modified string
        new_value = f"{prefix}{delim}{modified_inner_string}{delim}"

        logger.debug(
            f"multiline string dedented\n"
            f"before:\n{original_node.value}\n"
            f"after:\n{new_value}"
        )

        # Return the updated SimpleString node with the modified value
        return updated_node.with_changes(value=new_value)

    # format string
    def leave_FormattedString(
        self, original_node: cst.FormattedString, updated_node: cst.FormattedString
    ) -> cst.FormattedString:
        """Dedent triple-quoted formatted strings in the source code.

        The literal is re-rendered from the module source, dedented as text and
        parsed again.

        Raises:
            RuntimeError: If the dedented text no longer parses as a
                ``libcst.FormattedString``.
        """
        if len(original_node.end) != 3:
            return updated_node

        def get_num_leading_whitespace(line: str) -> int:
            """Compute the leading whitespace of this line."""
            if g := re.match(r"^[ \t]*", line):
                return len(g.group())
            return 0

        def dedent_str(text: str) -> str:
            """Dedent ``text`` with cleandoc, then strip any margin that is left."""
            cleaned = inspect.cleandoc(text)
            lines = cleaned.splitlines()
            margin: Optional[int] = None
            for line in lines:
                n = get_num_leading_whitespace(line)
                # try to fix a bug in `libcst.code_for_node``
                # the leading whitespace of '}' is not resumed correctly
                # as well as the contents within `{` and `}` (TODO)
                if len(line) > n and line[n] != "}":
                    margin = n if margin is None else min(margin, n)
            if margin is not None and margin > 0:
                new_lines = [
                    line[min(get_num_leading_whitespace(line), margin) :]
                    for line in lines
                ]
                # Join with "\n": this is source text that cleandoc splits on
                # "\n", so a platform-specific separator would leave stray "\r".
                cleaned = "\n".join(new_lines)
                cleaned = inspect.cleandoc(cleaned)  # clean again
                logger.warning(
                    f"unusual multiline formatted string detected, "
                    f"this is a bug related to libcst.code_for_node\n"
                    f"The string after manual fix is:\n{cleaned}"
                )
            logger.debug(
                f"multiline formmated string dedented\n"
                f"before:\n{text}\n"
                f"after:\n{cleaned}"
            )
            return cleaned

        original_code = self.module.code_for_node(original_node)
        simple_str = original_code[len(original_node.start) : -len(original_node.end)]
        # `start` is the prefix followed by the opening quotes.
        prefix = original_node.start[: -len(original_node.end)]
        dedented = self._make_closable(
            dedent_str(simple_str), prefix, original_node.end[0]
        )
        if dedented is None:
            logger.warning(
                "multiline raw formatted string ends with a quote character after "
                "dedent and cannot be closed safely, it is left unchanged:\n"
                f"{original_code}"
            )
            return updated_node
        format_str = f"{original_node.start}{dedented}{original_node.end}"
        parsed_cst = cst.parse_expression(format_str)
        if isinstance(parsed_cst, cst.FormattedString):
            return parsed_cst
        raise RuntimeError(
            f"Failed to parse modified formatted string as libcst.FormattedString: {format_str}"
        )
def dedent_triple_quoted_string(code: str) -> str:
    """Return ``code`` with every triple-quoted string dedented (via inspect.cleandoc).

    The code is parsed into a CST, rewritten by ``DedentTripleQuotedString``
    and rendered back to source text.
    """
    return cst.parse_module(code).visit(DedentTripleQuotedString()).code
def _get_docstring_quote_count(code: str) -> Optional[int]:
    """Return the number of quote characters delimiting the first function's docstring.

    Only the first top-level function in ``code`` is inspected.

    Returns:
        3 for a triple-quoted docstring, 1 for a single-quoted one, and ``None``
        when the first statement of the function is not a plain string literal.

    Raises:
        AssertionError: If ``code`` has no top-level function definition.
    """
    cst_tree = cst.parse_module(code)
    for node in cst_tree.body:
        if not isinstance(node, cst.FunctionDef):
            continue
        # Only detect the first func
        statement = node.body.body[0]
        if isinstance(statement, cst.SimpleStatementLine):
            first = statement.body[0]
            if isinstance(first, cst.Expr) and isinstance(
                first.value, cst.SimpleString
            ):
                # Skip a string prefix such as r/u/b so that r"""...""" is
                # recognized as triple-quoted too.
                literal = first.value.value.lstrip("rRbBuU")
                return 3 if literal.startswith(('"""', "'''")) else 1
        return None
    # should not reach here; raised explicitly so the check survives `python -O`
    raise AssertionError("no function detected")
class APPLCompiled:
    """A compiled APPL function that can be called with context."""

    def __init__(
        self, code: CodeType, ast: AST, original_func: Callable, compile_info: Dict
    ):
        """Initialize the compiled function.

        Args:
            code: The compiled code object.
            ast: The AST of the compiled code.
            original_func: The original function.
            compile_info: The compile information.
        """
        self._code = code
        self._ast = ast
        self._name = original_func.__name__
        self._original_func = original_func
        self._compile_info = compile_info
        self._docstring_quote_count = _get_docstring_quote_count(compile_info["source"])

    @property
    def freevars(self) -> Tuple[str, ...]:
        """Get the free variables of the compiled function."""
        if "freevars" in self._compile_info:
            return self._compile_info["freevars"]
        return self._original_func.__code__.co_freevars

    def __call__(
        self,
        *args: Any,
        _globals: Optional[Dict] = None,
        _locals: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Any:
        """Call the compiled function.

        Args:
            *args: Positional arguments for the function.
            _globals: Globals used to execute the compiled code; a missing or
                empty mapping falls back to the original function's globals.
                The ``appl`` module is inserted into it when absent.
            _locals: Mapping that supplies the values of the free (closure)
                variables; ``None`` is treated as an empty mapping.
            **kwargs: Keyword arguments for the function.

        Returns:
            Whatever the function returns.

        Raises:
            ValueError: If the value of a free variable cannot be determined.
        """
        _globals = _globals or self._original_func.__globals__
        if "appl" not in _globals:
            _globals["appl"] = sys.modules["appl"]
        caller_locals: Dict = {} if _locals is None else _locals
        # Namespace in which the function definition is executed.
        local_vars: Dict[str, Any] = {"PromptContext": PromptContext}
        # get closure variables from locals
        for name in self.freevars:
            if name in caller_locals:
                # set the closure variables to local_vars
                local_vars[name] = caller_locals[name]
            elif name == "__class__":
                # attempt for super() workaround: take the class from the first
                # positional argument.
                if not args:
                    raise ValueError(
                        "Freevar '__class__' not found and no positional argument "
                        "is available to infer it from."
                    )
                local_vars["__class__"] = args[0].__class__
            else:
                raise ValueError(
                    f"Freevar '{name}' not found. If you are using closure variables, "
                    "please provide their values in the _locals argument. "
                    "For example, assume the function is `func`, use `func(..., _locals=locals())`. "
                    "Alternatively, you can first use the `appl.as_func` to convert the "
                    "function within the current scope (automatically feeding the locals)."
                )

        # Executing the code defines the function inside `local_vars`.
        exec(self._code, _globals, local_vars)  # TODO: use closure argument
        func = local_vars[self._name]
        return func(*args, **kwargs)

    def __repr__(self) -> str:
        """Return a short representation with the function name."""
        return f"APPLCompiled({self._name})"
def appl_dedent(source: str) -> str:
    """Dedent the source code so that it can be parsed on its own.

    ``textwrap.dedent`` is tried first. If the result does not parse, the
    indentation of the first non-blank line is removed from the start of every
    line instead, and warnings are logged.
    """
    dedented = textwrap.dedent(source)
    try:
        ast.parse(dedented)
    except Exception:  # the dedent failed due to multiline string
        logger.warning(
            "The source code contains multiline string that cannot be dedented. "
            "It is recommended to dedent the multiline string aligning with the function, "
            "where APPL will automatically dedent the multiline string "
            "in the same way as cleaning docstring."
        )
        # Compute the dedent and remove the leading whitespace
        indent_pattern = re.compile("(^[ \t]*)(?:[^ \t\n])", re.MULTILINE)
        # use the first line as the standard
        first_indent = indent_pattern.findall(dedented)[0]
        dedented = re.sub(r"(?m)^" + first_indent, "", dedented)
        logger.warning(f"The source code after workaround:\n{dedented}")
    return dedented
def appl_compile(func: Callable) -> APPLCompiled:
    """Compile an APPL function.

    The source of ``func`` is dedented, its triple-quoted strings are cleaned,
    and the resulting AST is passed through the APPL transformers before being
    compiled.

    Args:
        func: The function to compile; its source must be retrievable with
            ``inspect``.

    Returns:
        The compiled function wrapper.
    """
    sourcefile = inspect.getsourcefile(func)
    # Read the source once; joining the lines is what inspect.getsource does.
    lines, lineno = inspect.getsourcelines(func)
    source = appl_dedent("".join(lines))
    # Register the dedented source under a synthetic filename so that the
    # compiled code (which uses `key` as its filename) can be resolved to text.
    key = f"<appl-compiled:{sourcefile}:{lineno}>"
    linecache.cache[key] = (
        len(source),
        None,
        # in-memory text lines are terminated with "\n" on every platform
        [line + "\n" for line in source.splitlines()],
        key,
    )

    source = dedent_triple_quoted_string(source)
    parsed_ast = ast.parse(source)
    logger.debug(
        f"\n{'-'*20} code BEFORE appl compile {'-'*20}\n{ast.unparse(parsed_ast)}"
    )

    # Applied in this order, each on the output of the previous one.
    transformers = [
        RemoveApplDecorator,
        SplitString,
        CallWithContext,
        AddCtxToArgs,
        AddExecuteWrapper,
    ]
    compile_info = {
        "source": source,
        "sourcefile": sourcefile,
        "lineno": lineno,
        "func_name": func.__name__,
        "freevars": func.__code__.co_freevars,
        "docstring": func.__doc__,
    }
    for transformer in transformers:
        parsed_ast = transformer(compile_info).visit(parsed_ast)

    # Nodes created by the transformers carry no positions yet.
    parsed_ast = ast.fix_missing_locations(parsed_ast)
    compiled_ast = compile(parsed_ast, filename=key, mode="exec")
    logger.debug(
        f"\n{'-'*20} code AFTER appl compile {'-'*20}\n{ast.unparse(parsed_ast)}"
    )

    return APPLCompiled(compiled_ast, parsed_ast, func, compile_info)