During my time as a student researcher at NVIDIA, my team worked on improving code generation models. We carefully studied the different types of errors that the code generation models were producing, explored ways to categorise them, and looked for ways to automatically detect the type of error that was produced.

We first started off by using open source linters and static analysis tools such as pylint and flake8 to understand what kinds of errors they were able to pick up and report. However, we quickly realized that the error messages they produced were not specific enough. This is largely because, in many examples, the generated code was not executable.

Under the hood, all these static analysis tools make use of Python’s ast module (Abstract Syntax Tree). It breaks Python source text down into entities which make up the nodes of the syntax tree. In fact, this is roughly what happens every time a Python script is run - the text is parsed into an AST, compiled to bytecode, and then executed.
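
As a quick illustration of what the parser produces, ast.parse turns source text into a tree whose nodes can be printed with ast.dump (the indent argument requires Python 3.9+):

```python
import ast

# Parse a one-line assignment and print the resulting tree structure.
tree = ast.parse("x = 1 + 2")
print(ast.dump(tree, indent=4))
```

The output shows an Assign node whose value is a BinOp node - exactly the kind of structure the tools above walk over.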

Anyway, since the open source static analysis tools were not suitable for our use case, we decided to write our own code analyser using the ast module. The module is easy to use and gives us plenty of control over what we would like to achieve. Look at the examples below to get a sense of what it can do.

  1. Subclass ast.NodeVisitor

     import ast

     class SyntaxAnalyser(ast.NodeVisitor):
         def __init__(self):
             ...


    We first need to create a new class that specifies everything we would like our code analyser to do (e.g. counting functions or loops). This class should inherit from the ast.NodeVisitor class, because that class handles the traversal of the AST for us. The traversal is depth-first.
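
    To see the depth-first order for yourself, a tiny visitor (NodePrinter is just an illustrative name) can log every node type in the order it is reached:

    ```python
    import ast

    class NodePrinter(ast.NodeVisitor):
        # Log each node's type, then keep descending into its children.
        def generic_visit(self, node):
            print(type(node).__name__)
            super().generic_visit(node)

    NodePrinter().visit(ast.parse("def f():\n    return 1"))
    # Module is printed first, then FunctionDef, then the nodes inside it.
    ```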

  2. Include custom attributes

     class SyntaxAnalyser(ast.NodeVisitor):
         def __init__(self):
             self.imported = []
             self.created_classes = []
             self.libraries_used = []
             self.assigned_variables = []
             self.function_data = dict()
    

    This step is pretty straightforward: these attributes track/collect any relevant information about your code as the tree is being traversed.

  3. Override functions

     class SyntaxAnalyser(ast.NodeVisitor):
         def __init__(self):
             ...

         def visit_Call(self, node):
             ...

         def visit_Attribute(self, node):
             ...

         def visit_ClassDef(self, node):
             ...
    

    This is where we specify exactly what our code analyser should do when it encounters a particular type of node. For example, if we would like to count the number of functions in our code and also store their names, we can override the visit_FunctionDef method to extract the function name and increment a counter whenever a FunctionDef node is encountered.
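
    A minimal sketch of that idea (FunctionCounter is a hypothetical name, not part of the ast module):

    ```python
    import ast

    class FunctionCounter(ast.NodeVisitor):
        def __init__(self):
            self.count = 0
            self.names = []

        # Called for each function definition the traversal reaches; since we
        # do not descend into the body here, nested functions are not counted.
        def visit_FunctionDef(self, node):
            self.count += 1
            self.names.append(node.name)

    counter = FunctionCounter()
    counter.visit(ast.parse("def foo():\n    pass\n\ndef bar():\n    pass"))
    print(counter.count, counter.names)  # 2 ['foo', 'bar']
    ```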

  4. Search recursively with self.generic_visit(node)

     class SyntaxAnalyser(ast.NodeVisitor):
         def __init__(self):
             ...

         def visit_FunctionDef(self, node):
             ...
             # do something - extract the function name...
             self.generic_visit(node)
    

    self.generic_visit(node) is a handy method that recursively visits everything within the node’s subtree. This means we can easily count nested functions or loops.
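
    For example, a hypothetical NestedFunctionCounter only sees the inner function when it keeps descending:

    ```python
    import ast

    class NestedFunctionCounter(ast.NodeVisitor):
        def __init__(self, descend=True):
            self.count = 0
            self.descend = descend

        def visit_FunctionDef(self, node):
            self.count += 1
            if self.descend:
                self.generic_visit(node)  # recurse into the function body

    code = "def outer():\n    def inner():\n        pass"
    with_recursion = NestedFunctionCounter(descend=True)
    with_recursion.visit(ast.parse(code))
    without_recursion = NestedFunctionCounter(descend=False)
    without_recursion.visit(ast.parse(code))
    print(with_recursion.count, without_recursion.count)  # 2 1
    ```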

  5. Complete example - how to find unused imports:

     import ast
     import textwrap

     class UnusedImportVisitor(ast.NodeVisitor):
         def __init__(self):
             self.imports = set()
             self.used = set()

         def visit_Import(self, node):
             for alias in node.names:
                 # use the bound name, so "import numpy as np" records "np"
                 self.imports.add(alias.asname or alias.name)

         def visit_ImportFrom(self, node):
             for alias in node.names:
                 self.imports.add(alias.asname or alias.name)

         def visit_Name(self, node):
             if isinstance(node.ctx, ast.Load) and node.id in self.imports:
                 self.used.add(node.id)

         def report_unused_imports(self):
             unused_imports = sorted(self.imports - self.used)
             if unused_imports:
                 print(f"Unused imports found: {unused_imports}")

     if __name__ == "__main__":
         # textwrap.dedent strips the leading indentation, otherwise
         # ast.parse would raise an IndentationError on the snippet
         code = textwrap.dedent("""
             import sys
             from os import path
             import datetime
             from math import sqrt

             def main():
                 print(sqrt(16))

             if __name__ == "__main__":
                 main()
             """)
         tree = ast.parse(code)
         visitor = UnusedImportVisitor()
         visitor.visit(tree)
         visitor.report_unused_imports()
    

Why even use AST?

It might seem like the AST module is not doing much, and that there are simpler ways of doing things like counting the number of functions or imports purely using text analysis. For example:

filename = "path/to/file.py"
function_count = 0

with open(filename, "r") as f:
    for line in f:
        if "def" in line.strip().split():
            function_count += 1

print(f"There are {function_count} functions in {filename}")

But what if file.py looks like this (example of generated code from codegen):

def mapValues(self, f):
        """
        Pass each value in the key-value pair RDD through a map function
        without changing the keys; this also retains the original RDD's
        partitioning.

        >>> x = sc.parallelize([("a", ["apple", "banana", "lemon"]), ("b", ["grapes"])])
        >>> def f(x): return len(x)
        >>> x.mapValues(f).collect()
        [('a', 3), ('b', 1)]
        """
        return self._mapValues(f, self._jrdd.mapValues)

or this:

def is_palindrome(string):
    """
    Check if a given string is a palindrome.

    >>> is_palindrome('racecar')
    True
    >>> is_palindrome('hello')
    False
    >>> is_palindrome('')
    True
    """
    return string == string[::-1]

"""
def is_palindrome(word):
    if len(word) <= 1:
        return True
    else:
        return word[0] == word[-1] and is_palindrome(word[1:-1])
"""

The problem is that string-matching approaches are not robust enough: they miscount whenever something that merely resembles a function definition appears in a docstring, comment, or string literal, as in the examples above. AST-based approaches avoid these issues because they only see real syntax nodes.

However, using AST still might not be useful if the code does not parse in the first place. In this case, one approach could be to remove lines from the end of the file one by one and retry until the remaining prefix compiles successfully. Of course, if the error is on the very first line, the whole file remains impossible to analyse using AST.
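
A rough sketch of that trimming fallback (parse_longest_prefix is a hypothetical helper, not something we shipped):

```python
import ast

def parse_longest_prefix(source):
    """Drop lines from the end until the remaining prefix parses.

    Returns (tree, lines_kept), or (None, 0) if no prefix parses at all.
    """
    lines = source.splitlines()
    for end in range(len(lines), 0, -1):
        try:
            return ast.parse("\n".join(lines[:end])), end
        except SyntaxError:
            continue
    return None, 0

broken = "def f():\n    return 1\ndef g(:"  # truncated, unparseable tail
tree, kept = parse_longest_prefix(broken)
print(kept)  # 2
```

This keeps the analysable prefix of a partially generated file, at the cost of silently discarding everything after the first unparseable line.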