# Mirror of https://github.com/osmarks/random-stuff
# (synced 2024-12-26 18:10:34 +00:00; 78 lines, 2.7 KiB, Python)
import ast
|
|
import inspect
|
|
import types
|
|
|
|
class Sentinel:
    """Marker type whose only instance stands for "no value"."""

    def __repr__(self):
        # A readable repr makes the marker recognisable in logs and
        # debugger output instead of an anonymous object address.
        return "<SENTINEL>"


# Unique placeholder object; compare against it with ``is``, never ``==``.
SENTINEL = Sentinel()
|
|
|
|
class _RecursiveCallRewriter(ast.NodeTransformer):
    """Replace direct calls to one named function with ``yield`` requests.

    A call ``name(a, b, k=v)`` becomes ``yield ((a, b), {"k": v})`` so that a
    driver loop can perform the call itself and send the result back in.
    """

    # Constructs that open their own scope.  A ``yield`` placed inside them
    # would either belong to the wrong function or fail to compile, so calls
    # found there are left as ordinary calls.
    _NESTED_SCOPES = (
        ast.FunctionDef,
        ast.AsyncFunctionDef,
        ast.Lambda,
        ast.ClassDef,
        ast.ListComp,
        ast.SetComp,
        ast.DictComp,
        ast.GeneratorExp,
    )

    def __init__(self, name):
        self.name = name
        self.rewrote_any = False

    def visit(self, node):
        if isinstance(node, self._NESTED_SCOPES):
            return node
        return super().visit(node)

    def visit_Call(self, node):
        # Descend first so nested recursive calls such as f(f(n - 1)) are
        # converted as well, innermost first.
        self.generic_visit(node)
        # Only a plain name can be the recursive call; attribute or subscript
        # callees (obj.method(), table[i]()) have no ``id`` attribute.
        if not (isinstance(node.func, ast.Name) and node.func.id == self.name):
            return node
        self.rewrote_any = True
        positional = ast.Tuple(elts=node.args, ctx=ast.Load())
        # In ast.Dict a ``None`` key means ``**value``, which is exactly how
        # ast.keyword marks ``**mapping`` arguments (``arg is None``).
        keywords = ast.Dict(
            keys=[
                None if kw.arg is None else ast.Constant(value=kw.arg)
                for kw in node.keywords
            ],
            values=[kw.value for kw in node.keywords],
        )
        request = ast.Yield(
            value=ast.Tuple(elts=[positional, keywords], ctx=ast.Load())
        )
        return ast.copy_location(request, node)


def rewrite_recursion(f):
    """Return a version of ``f`` whose self-recursion does not use the call stack.

    The source of ``f`` is re-parsed, every direct call to its own name is
    turned into a ``yield`` of the call's arguments, and the resulting
    generator function is driven by an explicit stack.  Works both as a
    decorator and when called directly on an existing function.

    Limitations: the source of ``f`` must be retrievable with ``inspect``;
    only calls spelled with the function's bare name are rewritten (calls
    inside nested functions, lambdas, classes and comprehensions are left as
    normal calls); variables captured from an enclosing function are not
    available to the recompiled code; ``f`` must not itself be a generator.

    Args:
        f: a function defined with ``def``.

    Returns:
        A wrapper with the same calling convention as ``f``, or ``f`` itself
        when it contains no rewritable recursive call.

    Raises:
        TypeError: if the source of ``f`` does not contain a ``def``.
        OSError: propagated from ``inspect`` when the source is unavailable.
    """
    import functools
    import textwrap

    source_lines, first_lineno = inspect.getsourcelines(f)
    # Dedent so functions defined inside another block still parse.
    tree = ast.parse(textwrap.dedent("".join(source_lines)))
    outer = next(
        (node for node in tree.body if isinstance(node, ast.FunctionDef)), None
    )
    if outer is None:
        raise TypeError("rewrite_recursion needs a function defined with 'def'")

    rewriter = _RecursiveCallRewriter(outer.name)
    outer.body = [rewriter.visit(statement) for statement in outer.body]
    if not rewriter.rewrote_any:
        # Nothing to trampoline; the recompiled function would not even be a
        # generator, so hand back the original untouched.
        return f

    # Decorators are dropped in the AST rather than by slicing source lines,
    # so any number of decorator lines (including none) is handled.
    outer.decorator_list = []
    tree.body = [outer]
    ast.fix_missing_locations(tree)
    # Line 1 of the parsed snippet is line ``first_lineno`` of the real file;
    # shifting keeps tracebacks pointing at the original source.
    ast.increment_lineno(tree, first_lineno - 1)
    module_code = compile(tree, f.__code__.co_filename, "exec")
    # Select by name: default-argument lambdas also appear among the consts.
    function_code = next(
        const
        for const in module_code.co_consts
        if isinstance(const, types.CodeType) and const.co_name == outer.name
    )
    inner_function = types.FunctionType(
        function_code, f.__globals__, f.__name__, f.__defaults__
    )
    inner_function.__kwdefaults__ = f.__kwdefaults__

    @functools.wraps(f)
    def trampoline(*args, **kwargs):
        stack = [inner_function(*args, **kwargs)]
        result = None  # value delivered to the frame on top of the stack
        error = None  # exception to raise at that frame's pending call site
        while stack:
            frame = stack[-1]
            try:
                if error is None:
                    # send(None) on a fresh generator is equivalent to next().
                    request = frame.send(result)
                else:
                    pending, error = error, None
                    request = frame.throw(pending)
            except StopIteration as finished:
                stack.pop()
                result = finished.value
                continue
            except BaseException as raised:
                stack.pop()
                if not stack:
                    raise
                # Let the caller's own try/except see the callee's exception.
                error = raised
                continue

            call_args, call_kwargs = request
            try:
                callee = inner_function(*call_args, **call_kwargs)
            except TypeError as bad_call:
                # Argument-binding failure: surface it at the call site.
                error = bad_call
                continue
            stack.append(callee)
            result = None
        return result

    return trampoline
|
|
|
|
@rewrite_recursion
def rec_demo(n):
    """Step down from ``n`` by one per recursive call; give back ``n`` once it is 1 or less."""
    return n if n <= 1 else rec_demo(n - 1)


# Demo: recurse far deeper than CPython's default recursion limit normally allows.
print(rec_demo(10000))