A simple .py source file parser using AST
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

536 lines
20 KiB

#!/usr/bin/env python3
# A simple python source code parser
# Author: Wengling Chen
# Date: Sat, 15 Feb 2020 06:18:42 -0500
from _ast import AST
import ast
import os
import pprint
import sys
import traceback
import warnings
from collections import defaultdict
from enum import Enum
class BranchingType(Enum):
    """Tag telling which arm of an ``if`` statement a tree node stands for."""

    # Node built from the test and body of an ``if`` statement.
    If = 1
    # Node built from the ``orelse`` part of the same statement.
    Else = 2
class BranchingTreeNode:
    """Node for branch tracing. This is an n-ary tree node.

    A node stores an arbitrary value, an optional type tag, a link to its
    parent, its child nodes and a per-variable history (``var_list``).
    """

    def __init__(self, val, type=None):
        """
        @param val: value carried by the node
        @param type: optional tag describing the node (None when omitted)
        """
        self._val = val
        self._type = type
        self._parent = None  # filled in by the parent's add_child()
        self._var_list = defaultdict(list)  # variable name -> history entries
        self._subtrees = []  # child nodes, in insertion order

    # ---- read access to the stored fields -------------------------------

    @property
    def val(self):
        return self._val

    @property
    def type(self):
        return self._type

    @property
    def var_list(self):
        return self._var_list

    @property
    def parent(self):
        return self._parent

    @parent.setter
    def parent(self, node):
        self._parent = node

    # ---- tree manipulation ----------------------------------------------

    def add_child(self, node):
        """Attach ``node`` as the last child and make this node its parent."""
        self._subtrees.append(node)
        node.parent = self

    def get_child(self):
        """Return the list of child nodes (the internal list, not a copy)."""
        return self._subtrees

    def add_to_all_child_varlist(self, name, val):
        """Append ``val`` to the history of ``name`` on this node and on every
        node below it."""
        pending = [self]
        while pending:
            current = pending.pop()
            current.var_list[name].append(val)
            pending.extend(current.get_child())
class BranchingTree:
    """Contains branching information of a function.

    The root node wraps the ``ast.FunctionDef``.  ``build`` adds one
    ``BranchingType.If`` node for every ``ast.If`` found below it and, when
    that statement has an ``orelse`` part, one ``BranchingType.Else`` node
    wrapping the same ``ast.If``.
    """

    def __init__(self, func_def):
        """
        @param func_def: the ``ast.FunctionDef`` whose branches are traced
        @raise RuntimeError: if ``func_def`` is not an ``ast.FunctionDef``
        """
        if not isinstance(func_def, ast.FunctionDef):
            raise RuntimeError("Error building branching tree.")
        self.root = BranchingTreeNode(func_def)

    def search(self, val, type):
        """DFS search of the tree.

        @param val: value to look for (compared with ``==`` against node.val)
        @param type: required node type, or None to accept any type
        @return: the first matching node in pre-order, or None
        """
        return self._search_tree_sub(self.root, val, type)

    def _search_tree_sub(self, node, val, type):
        """Search ``node`` and its descendants; see ``search``."""
        if node.val == val and (type is None or node.type == type):
            return node
        for child in node.get_child():
            found = self._search_tree_sub(child, val, type)
            if found is not None:
                return found
        return None

    def walk(self, node):
        """Depth-first (pre-order) walk from ``node``; return all var_list."""
        collected = [node.var_list]
        for child in node.get_child():
            collected.extend(self.walk(child))
        return collected

    def build(self):
        """Populate the tree from the function definition held by the root."""
        self._build_subroutine(self.root, self.root.val)

    def _build_subroutine(self, node, ast_node):
        """Add the branches found in ``ast_node`` below tree node ``node``."""
        if ast_node.__class__.__name__ != "If":
            # Not a branch itself: keep looking through all of its fields.
            for _field, value in ast.iter_fields(ast_node):
                self._ast_node_visit(node, value)
            return

        if not self.search(ast_node, BranchingType.If):
            if_node = BranchingTreeNode(ast_node, BranchingType.If)
            node.add_child(if_node)
            self._ast_node_visit(if_node, ast_node.body)
            self._ast_node_visit(if_node, ast_node.test)

        if ast_node.orelse and not self.search(ast_node, BranchingType.Else):
            else_node = BranchingTreeNode(ast_node, BranchingType.Else)
            # The else statements are visited last-to-first, and the node is
            # attached only after its subtree has been built.
            for stmt in reversed(ast_node.orelse):
                self._ast_node_visit(else_node, stmt)
            node.add_child(else_node)

    def _ast_node_visit(self, node, value):
        """Recurse into ``value`` when it is an AST node or a list of them."""
        candidates = value if isinstance(value, list) else [value]
        for item in candidates:
            if isinstance(item, ast.AST):
                self._build_subroutine(node, item)

    def __repr__(self):
        return self._to_string_sub(self.root, "")

    # str() and repr() produce the same rendering.
    __str__ = __repr__

    def _to_string_sub(self, node, indent):
        """Render ``node`` and its subtree, one level of ``indent`` deeper per
        generation; ``If`` children are printed before ``Else`` children."""
        text = indent + str(node.type) + "\n"
        if node.val.__class__.__name__ == "If":
            # pprint raises ValueError for a negative indent, which
            # len(indent) - 4 is for shallow nodes, so clamp it at zero.
            text += indent + " " + pprint.pformat(
                node.var_list, compact=False, indent=max(len(indent) - 4, 0))
        text += "\n"

        children = node.get_child()
        if_children = [c for c in children if c.type == BranchingType.If]
        else_children = [c for c in children if c.type == BranchingType.Else]
        for child in if_children + else_children:
            text += self._to_string_sub(child, indent + " ")
        return text
def iter_fields(node):
    """
    Generate ``(fieldname, value)`` pairs, in ``node._fields`` order, for the
    fields that actually exist on *node*; absent fields are skipped silently.
    """
    absent = object()  # sentinel: distinguishes "missing" from a None value
    for name in node._fields:
        value = getattr(node, name, absent)
        if value is not absent:
            yield name, value
class NodeVisitor(ast.NodeVisitor):
    """AST visitor that builds one summary dict per function definition.

    Every ``FunctionDef`` reached through ``visit`` is summarised (arguments,
    statement / call / operator statistics, branching tree and return
    statements) and handed to ``caller.add_func_def``.
    """

    # Return type assumed for calls to these built-in functions.
    _BUILTIN_RETURN_TYPES = {"len": "int", "int": "int", "str": "str", "float": "float"}
    # Container displays whose inferred type is simply their node class name.
    _CONTAINER_NODES = ("Tuple", "List", "Set", "Dict")

    def __init__(self, caller):
        """
        @param caller: object whose ``add_func_def`` receives each summary
        """
        super().__init__()
        self.caller = caller

    def visit(self, node):
        """Visit a node."""
        handler = getattr(self, 'visit_' + node.__class__.__name__, self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for _field, value in ast.iter_fields(node):
            for child in (value if isinstance(value, list) else [value]):
                if isinstance(child, ast.AST):
                    self.visit(child)

    def func_body_visit(self, node, body_stat, func_stat, op_stat, branching_tree):
        """Visit a node. Function body ver."""
        return self.func_body_generic_visit(node, body_stat, func_stat, op_stat, branching_tree)

    def func_body_generic_visit(self, node, body_stat, func_stat, op_stat, branching_tree):
        """Visit function used only for tracing variables and operations in function bodies.

        Children are processed before ``node`` itself so that their
        ``ret_type`` attribute is available when the parent is handled.

        @param node: AST node inside a function body
        @param body_stat: counter of node class names, incremented in place
        @param func_stat: counter of called function names, incremented in place
        @param op_stat: list receiving one dict per binary operation
        @param branching_tree: tree that stores the per-branch variable history
        @raise NotImplementedError: when a ``BoolOp`` is encountered
        @raise RuntimeError: for expression kinds that are not handled
        """
        # TODO: class is not considered currently
        cls_name = node.__class__.__name__
        for _field, value in ast.iter_fields(node):
            for child in (value if isinstance(value, list) else [value]):
                if isinstance(child, ast.AST):
                    self.func_body_visit(child, body_stat, func_stat, op_stat, branching_tree)

        # Operation stats
        body_stat[cls_name] += 1

        # Function calling stats
        if cls_name == "Call":
            callee = NodeVisitor._callee_name(node.func)
            # Calls through an arbitrary expression (e.g. ``f()()``) have no
            # name to count, so they are left out of the statistics.
            if callee is not None:
                func_stat[callee] += 1

        # Variable tracing
        if cls_name == "Assign":
            for target in node.targets:
                for name, value_type in NodeVisitor._assigned_names(target, node.value.ret_type):
                    NodeVisitor._add_var_trace(branching_tree, name, "Assign", value_type, node)
        elif isinstance(node, ast.stmt):
            node.ret_type = None
        elif isinstance(node, ast.expr):
            node.ret_type = NodeVisitor._expr_type(node, cls_name, op_stat, branching_tree)

    @staticmethod
    def _callee_name(func):
        """Return the name a call goes through, or None if it has none."""
        kind = func.__class__.__name__
        if kind == "Attribute":
            return func.attr
        if kind == "Name":
            return func.id
        return None

    @staticmethod
    def _assigned_names(target, value_type):
        """Yield ``(name, type)`` for every plain variable bound by ``target``.

        Targets that are not variables (e.g. ``obj.attr``) yield nothing.
        """
        kind = target.__class__.__name__
        if kind == "Name":
            yield target.id, value_type
        elif kind in ("Tuple", "List"):
            # Unpacking: the type of each individual element is not known.
            for element in target.elts:
                yield from NodeVisitor._assigned_names(element, None)

    @staticmethod
    def _operand_type(branching_tree, operand):
        """Infer the type name of an already visited operand, or None."""
        kind = operand.__class__.__name__
        if kind == "Constant":
            return type(operand.value).__name__
        if kind == "Name":
            history = NodeVisitor._get_var_list(branching_tree, operand)[operand.id]
            # The variable is not initialized inside context, so no way to determine type
            return history[-1]["type"] if history else None
        return operand.ret_type  # Inherited from the previous operation

    @staticmethod
    def _expr_type(node, cls_name, op_stat, branching_tree):
        """Return the inferred type of expression ``node``, recording traces."""
        if cls_name == "BinOp":
            return NodeVisitor._binop_type(node, op_stat, branching_tree)
        if cls_name == "BoolOp":
            raise NotImplementedError("Not yet implemented")
        if cls_name in NodeVisitor._CONTAINER_NODES:
            return cls_name
        if cls_name == "Call":
            # Only direct calls to the known built-ins have a known type.
            if node.func.__class__.__name__ == "Name":
                return NodeVisitor._BUILTIN_RETURN_TYPES.get(node.func.id)
            return None
        if cls_name == "Compare":
            return NodeVisitor._compare_type(node, branching_tree)
        if cls_name in ("Name", "Constant"):
            return NodeVisitor._operand_type(branching_tree, node)
        if cls_name == "Attribute":
            owner = node.value
            if owner.__class__.__name__ == "Name":
                owner_type = NodeVisitor._operand_type(branching_tree, owner)
                NodeVisitor._add_var_trace(branching_tree, owner.id, "Attribute", owner_type, owner)
            return None  # No easy way to know
        # The rest are unlikely to be used
        raise RuntimeError("Unexpected expression encountered. Please inspect manually.")

    @staticmethod
    def _binop_type(node, op_stat, branching_tree):
        """Handle a ``BinOp``: infer its type, trace operands, log the op."""
        operands = (node.left, node.right)
        types = [NodeVisitor._operand_type(branching_tree, operand) for operand in operands]
        left_type, right_type = types
        op_name = node.op.__class__.__name__

        # Default to "cannot be determined"; this also covers operand types
        # other than str/int/float, which previously left the result unbound.
        out_type = None
        if left_type is not None:
            if left_type == "str":
                out_type = "str"
            elif left_type in ("int", "float"):
                out_type = "float"  # For simplicity
        elif right_type is not None:
            if right_type == "str":
                out_type = "str"
            elif right_type in ("int", "float") and isinstance(node.op, (ast.Add, ast.Sub)):
                out_type = "float"
        else:
            warnings.warn("Type of both sides of an expression cannot be determined!")

        # Add variable change to list
        for operand, operand_type in zip(operands, types):
            if operand.__class__.__name__ == "Name":
                NodeVisitor._add_var_trace(branching_tree, operand.id, op_name, operand_type, operand)

        # Add OP stats
        op_stat.append({'type': op_name, 'left': left_type, 'right': right_type})
        return out_type

    @staticmethod
    def _compare_type(node, branching_tree):
        """Handle a ``Compare``: trace every variable operand, return "bool"."""
        assert len(node.ops) == len(node.comparators)
        operands = [node.left, *node.comparators]
        types = [NodeVisitor._operand_type(branching_tree, operand) for operand in operands]
        for index, (operand, operand_type) in enumerate(zip(operands, types)):
            if operand.__class__.__name__ == "Name":
                # The left operand uses the first operator; each comparator
                # uses the operator that precedes it.
                operator = node.ops[max(index - 1, 0)]
                NodeVisitor._add_var_trace(branching_tree, operand.id,
                                           operator.__class__.__name__, operand_type, operand)
        # Common results should all be boolean
        return "bool"

    def visit_FunctionDef(self, node):
        """Summarise one function definition and pass it to the caller.

        Do not try to change the function name!  Only the arguments in
        ``node.args.args`` are listed; positional-only and keyword-only
        arguments are not recorded.
        """
        args = node.args
        func_def = {
            "name": node.name,
            "args": [{"name": arg.arg} for arg in args.args],
            "vararg": args.vararg is not None,
            "kwarg": args.kwarg is not None,
        }
        # Defaults belong to the trailing arguments.  Pairing from the right
        # stops at the shorter sequence, so surplus defaults (those of
        # positional-only arguments) can no longer wrap around and overwrite
        # the default of another argument.
        for entry, default in zip(reversed(func_def["args"]), reversed(args.defaults)):
            if default.__class__.__name__ == "Constant":
                entry["default_val"] = default.value
            else:
                warnings.warn("Non-standard argument encountered. Visual inspection recommended.")
                entry["default_val"] = ast.dump(default)

        # Parse body
        # Collect used operators, statements and function names. Note that elif will be counted as multiple ifs.
        # It also traces a history of all operations on variables. This can be used to detect e.g. str * int.
        # The tracing is error prone, since Python is dynamic type.
        # Add parent to nodes
        node.parent = None
        node.from_else = False
        for parent in ast.walk(node):
            else_part = parent.orelse if parent.__class__.__name__ == "If" else []
            for child in ast.iter_child_nodes(parent):
                child.from_else = child in else_part
                child.parent = parent

        # Build branching tree
        branching_tree = BranchingTree(node)
        branching_tree.build()

        # Variable tracing
        body_stat = defaultdict(int)
        func_stat = defaultdict(int)
        op_list = []
        for stmt in node.body:
            self.func_body_visit(stmt, body_stat, func_stat, op_list, branching_tree)
        func_def["body_stat"] = body_stat
        func_def["func_stat"] = func_stat
        func_def["op_list"] = op_list
        func_def["branching_tree"] = branching_tree

        # Parse returns
        func_def["returns"] = [
            NodeVisitor._describe_return(sub.value)
            for stmt in node.body
            for sub in ast.walk(stmt)
            if sub.__class__.__name__ == "Return"
        ]
        self.caller.add_func_def(func_def)

    @staticmethod
    def _describe_return(value):
        """Describe the value expression of one ``return`` statement."""
        if value is None:
            # Empty return
            return {"val": None, "type": None}
        kind = value.__class__.__name__
        if kind == "Constant":
            return {"val": value.value, "type": "Constant"}
        if kind == "Name":
            return {"val": value.id, "type": "Variable"}
        # the statement returns non-standard value (list, tuple, dict, func, etc.)
        warnings.warn("Non-standard return statement encountered. Visual inspection recommended.")
        return {"val": ast.dump(value), "type": kind}

    @staticmethod
    def _add_var_trace(branching_tree, name, op_type, ret_type, node):
        """Record ``{"op", "type"}`` for ``name`` on the branch that encloses
        ``node`` and on all of its sub-branches (the root when there is none)."""
        # Find last branching spot
        cur_node, search_type = NodeVisitor._find_branching_node(node)
        var_hist = {"op": op_type, "type": ret_type}
        tree_node = branching_tree.search(cur_node, search_type)
        # Need to add history to all children of current node
        if tree_node is None:
            tree_node = branching_tree.root
        tree_node.add_to_all_child_varlist(name, var_hist)

    @staticmethod
    def _find_branching_node(node):
        """Climb the ``parent`` links to the nearest enclosing ``If``.

        @return: ``(if_node, BranchingType)`` for that statement, or
                 ``(None, None)`` when ``node`` is not inside any ``If``
        """
        cur_node = node
        search_type = None
        while cur_node:
            parent = cur_node.parent
            if parent and parent.__class__.__name__ == "If":
                search_type = BranchingType.Else if cur_node.from_else else BranchingType.If
                cur_node = parent
                break
            cur_node = parent
        return cur_node, search_type

    @staticmethod
    def _get_var_list(branching_tree, node):
        """Return the variable history of the branch that encloses ``node``
        (the root's history when there is none)."""
        # Find last branching spot
        cur_node, search_type = NodeVisitor._find_branching_node(node)
        tree_node = branching_tree.search(cur_node, search_type)
        if tree_node is None:
            return branching_tree.root.var_list
        return tree_node.var_list
class FileParser:
    """Parses Python source text and collects the functions it defines."""

    def __init__(self, f):
        """
        Read a python source file and parse it.
        @param f: a file handler; its whole content is read and handed to
                  ``ast.parse``
        """
        self.__func_list = {}  # function name -> function summary
        self.__tree = ast.parse(f.read())

    def add_func_def(self, func_def):
        """
        Store a function summary under its ``'name'`` entry; an existing
        entry with the same name is replaced.
        @param func_def: dict describing one function
        """
        name = func_def['name']
        self.__func_list[name] = func_def

    @property
    def func_list(self):
        """
        Return the parsed functions
        @return: the dict mapping each function name to its summary
        """
        return self.__func_list

    def get_used_func(self):
        """
        Get information on all functions defined in a python file. Note nested functions are flattened.
        The results are collected through ``add_func_def``.
        """
        NodeVisitor(self).visit(self.__tree)
class MyPrettyPrinter(pprint.PrettyPrinter):
    """Pretty printer with a dedicated renderer for ``BranchingTree``."""

    def _pprint_branching_tree(self, object, stream, indent, allowance, context, level):
        """Write ``object`` to ``stream`` as ``Branching Tree(<str(object)>)``."""
        for piece in ('Branching Tree(', str(object), ')'):
            stream.write(piece)

    # Class-private copy of the base dispatch table (a private pprint
    # attribute keyed by ``__repr__`` functions), extended with the renderer
    # above so the base class table itself is left unmodified.
    _dispatch = {
        **pprint.PrettyPrinter._dispatch,
        BranchingTree.__repr__: _pprint_branching_tree,
    }
if __name__ == "__main__":
    # Parse the file named 'testfile' in the working directory and pretty
    # print the summary of every function found in it.
    fp = None
    pp = MyPrettyPrinter()
    with open('testfile', 'r') as f:
        try:
            fp = FileParser(f)
        except SyntaxError:
            # The file is not valid Python: say so, then show the details.
            print("SyntaxError in opened file")
            traceback.print_exc()
        except Exception:
            traceback.print_exc()
    # fp stays None when parsing failed; there is nothing to report then.
    if fp is not None:
        fp.get_used_func()
        pp.pprint(fp.func_list)