Last active
August 29, 2015 14:06
-
-
Save dhermes/85c3a3a464d2cff312ea to your computer and use it in GitHub Desktop.
An `ast` parser which (mostly) correctly indents Python docstrings.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
try:
    import __builtin__
except ImportError:  # Python 3 renamed the module to ``builtins``.
    import builtins as __builtin__
import ast
import collections
import os
import shutil
import sys
import tempfile
# Delimiters that may open a docstring in the source text.
TRIPLE_QUOTES = ('"""', '\'\'\'')


# Also see: http://stackoverflow.com/a/17478618/1068170
# and: ('https://bitbucket.org/aivarannamaa/thonny/src/'
#       '85e09e98a08db63d75f158b435dd07fc7a00c27c/src/'
#       'ast_utils.py?at=default')
class HeavyAst(object):
    """A parsed Python file kept together with its raw source lines.

    Attributes:
        filename: Path of the file that was read (and that
            ``rewrite_file`` overwrites).
        ast_tree: Result of ``ast.parse`` on the file contents.
        as_lines: The file contents split on newlines. A trailing newline
            in the file shows up as a final empty string.
    """

    def __init__(self, filename):
        """Read and parse ``filename``.

        Args:
            filename: Path to a Python source file.
        """
        self.filename = filename
        with open(filename, 'r') as fh:
            contents = fh.read()
        # NOTE: Line endings are normalised by hand so that splitting on
        #       '\n' works. The 'rU' open mode used to do this, but it is
        #       rejected by newer Python 3 releases.
        contents = contents.replace('\r\n', '\n').replace('\r', '\n')
        self.ast_tree = ast.parse(contents)
        self.as_lines = contents.split('\n')

    def find_content(self, doc_str):
        """Locate the source line on which a docstring opens.

        The first line of ``doc_str`` must occur in exactly one source line.
        From there the search walks upwards until it reaches a line whose
        first non-blank characters are a triple quote.

        Args:
            doc_str: The docstring text to look for.

        Returns:
            The 1-indexed line number of the opening triple quote.

        Raises:
            ValueError: If the first line is missing or ambiguous, or if no
                line at or above the match starts with a triple quote.
        """
        # split() (rather than find() + slice) keeps a one-line docstring
        # intact; find() returns -1 there and would chop the last character.
        first_line = doc_str.split('\n', 1)[0]
        matches = [index for index, line in enumerate(self.as_lines)
                   if first_line in line]
        if len(matches) != 1:
            raise ValueError('Can\'t find line in source code.')

        candidate = matches[0]
        # Stop at the top of the file instead of wrapping around to
        # negative indexes.
        while candidate >= 0:
            if self.as_lines[candidate].lstrip().startswith(TRIPLE_QUOTES):
                # NOTE: We add 1 since the lines in the file are 1-indexed.
                return candidate + 1
            candidate -= 1
        raise ValueError('Can\'t find line in source code.')

    def rewrite_file(self, all_docstrings):
        """Overwrite the file with re-indented docstrings, keeping a backup.

        Both outputs are first written to temporary files, so nothing on
        disk changes if an error occurs while the lines are generated. The
        unmodified lines are then copied to ``<filename>.bak`` and the
        rewritten lines to ``<filename>``.

        Args:
            all_docstrings: Dictionary mapping the 1-indexed starting line
                of each docstring to an object that has an ``end``
                attribute and a ``write_line(fh, line_no)`` method.

        Raises:
            ValueError: If a docstring starts before the previous one ended.
        """
        temp_files = []
        try:
            for _ in range(2):
                temp_files.append(
                    tempfile.NamedTemporaryFile(mode='w', delete=False))
            backup_fh, new_fh = temp_files

            with backup_fh, new_fh:
                curr_docstring = None
                for line_no, line_val in enumerate(self.as_lines, 1):
                    if line_no != 1:
                        # First line does not have a preceding line.
                        new_fh.write('\n')
                        backup_fh.write('\n')
                    # Write the lines as-is to the backup.
                    backup_fh.write(line_val)

                    # Check if a docstring is starting.
                    if line_no in all_docstrings:
                        if curr_docstring is not None:
                            raise ValueError(
                                'Two docstrings can\'t be simultaneous.')
                        curr_docstring = all_docstrings[line_no]

                    if curr_docstring is None:
                        new_fh.write(line_val)
                    else:
                        curr_docstring.write_line(new_fh, line_no)
                        if line_no == curr_docstring.end:
                            curr_docstring = None

            backup_path = self.filename + '.bak'
            shutil.copyfile(backup_fh.name, backup_path)
            print('Created %s' % (backup_path,))
            shutil.copyfile(new_fh.name, self.filename)
            print('Over-wrote %s' % (self.filename,))
        finally:
            # Always close and delete the scratch files, even on failure.
            for temp_file in temp_files:
                temp_file.close()
                os.remove(temp_file.name)
class DocstringObj(object):
    """The location and cleaned-up text of a single docstring.

    Attributes:
        doc_str: The docstring as returned by ``ast.get_docstring``.
        doc_str_lines: ``doc_str`` split into lines and padded with empty
            strings so there is one entry per source line in the range
            ``start`` to ``end``.
        start: 1-indexed source line on which the docstring opens.
        end: 1-indexed source line on which the docstring closes.
        col_offset: Number of spaces written in front of every line.
    """

    def __init__(self, heavy_ast, ast_parent, ast_docstr_expr):
        """Work out where a docstring lives and how far to indent it.

        Args:
            heavy_ast: Object with a ``find_content(doc_str)`` method. It is
                only consulted when the AST node has no ``end_lineno``.
            ast_parent: The module, class or function node that owns the
                docstring.
            ast_docstr_expr: The expression node holding the docstring.

        Raises:
            ValueError: If the docstring text has more lines than the
                source range it was found in.
        """
        self.doc_str = ast.get_docstring(ast_parent)
        self.doc_str_lines = self.doc_str.split('\n')
        is_module = isinstance(ast_parent, ast.Module)

        end_lineno = getattr(ast_docstr_expr, 'end_lineno', None)
        if end_lineno is not None:
            # Python 3.8+ reports both ends of the expression, so the
            # position does not need to be guessed from the source text.
            self.start = ast_docstr_expr.lineno
            self.end = end_lineno
        else:
            # Older interpreters: ``lineno`` of a multi-line string is the
            # line it ends on, so the start has to be searched for.
            if is_module:
                self.start = 1
            else:
                self.start = heavy_ast.find_content(self.doc_str)
            self.end = ast_docstr_expr.lineno

        # Module docstrings sit at the left margin; everything else is
        # indented four spaces relative to its parent definition.
        if is_module:
            self.col_offset = 0
        else:
            self.col_offset = ast_parent.col_offset + 4

        stated_length = self.end - self.start + 1
        missing_length = stated_length - len(self.doc_str_lines)
        if missing_length < 0:
            raise ValueError(
                'Docstring is too long for reported start and end.')
        self.doc_str_lines += [''] * missing_length

    def get_line(self, line_no):
        """Return the docstring text belonging to source line ``line_no``."""
        return self.doc_str_lines[line_no - self.start]

    def write_line(self, fh, line_no):
        """Write the re-indented docstring line for ``line_no`` to ``fh``.

        The triple-double-quote delimiter is added in front of the first
        line and after the last one, and trailing whitespace is stripped.
        No newline is written.

        Args:
            fh: Writable text file handle.
            line_no: 1-indexed source line, between ``start`` and ``end``.
        """
        line_val = self.get_line(line_no)
        if line_no == self.start:
            line_val = '"""' + line_val
        if line_no == self.end:
            line_val += '"""'
        line_val = (' ' * self.col_offset) + line_val
        fh.write(line_val.rstrip())

    def __repr__(self):
        return 'DocstringObj(start=%d,end=%d)' % (self.start, self.end)
def _is_str_literal(node):
    """Return True if ``node`` is a string-literal expression node."""
    if sys.version_info >= (3, 8):
        # String literals are parsed as ``ast.Constant`` here; ``ast.Str``
        # is deprecated and missing from the newest releases.
        return isinstance(node, ast.Constant) and isinstance(node.value, str)
    return isinstance(node, ast.Str)


def get_docstring_obj(ast_obj, heavy_ast):
    """Build a ``DocstringObj`` for ``ast_obj`` if it has a docstring.

    Args:
        ast_obj: Any AST node.
        heavy_ast: Passed through unchanged to ``DocstringObj``.

    Returns:
        A ``DocstringObj`` when ``ast_obj`` is a module, class or function
        (including ``async def``) whose first statement is a string
        literal; otherwise ``None``.
    """
    # Only module, class and function/methods can have a docstring.
    owner_types = (ast.Module, ast.ClassDef, ast.FunctionDef)
    async_function_def = getattr(ast, 'AsyncFunctionDef', None)
    if async_function_def is not None:
        owner_types += (async_function_def,)
    if not isinstance(ast_obj, owner_types):
        return None

    obj_body = getattr(ast_obj, 'body', [])
    if not obj_body:
        return None

    docstring_candidate = obj_body[0]
    if (isinstance(docstring_candidate, ast.Expr) and
            _is_str_literal(docstring_candidate.value)):
        return DocstringObj(heavy_ast, ast_obj, docstring_candidate)
    return None
def _get_all_docstrings(ast_obj, result, heavy_ast):
    """Collect the docstrings of ``ast_obj`` and all of its descendants.

    Every node below ``ast_obj`` is visited, so definitions nested under
    ``else``, ``except`` or ``finally`` clauses are found as well as those
    directly in a ``body``.

    Args:
        ast_obj: AST node at which to start.
        result: Dictionary that is updated in place; maps the starting line
            of each docstring to its ``DocstringObj``.
        heavy_ast: Passed through unchanged to ``get_docstring_obj``.

    Raises:
        KeyError: If two docstrings report the same starting line.
    """
    # ast.walk is iterative, so deeply nested expressions cannot exhaust
    # the recursion limit.
    for node in ast.walk(ast_obj):
        docstring_obj = get_docstring_obj(node, heavy_ast)
        if docstring_obj is None:
            continue
        if docstring_obj.start in result:
            raise KeyError(
                'Start: %d already in result.' % (docstring_obj.start,))
        result[docstring_obj.start] = docstring_obj
def get_all_docstrings(heavy_ast):
    """Gather every docstring found in a parsed module.

    Args:
        heavy_ast: Object whose ``ast_tree`` attribute is the parsed module.

    Returns:
        Dictionary mapping each docstring's starting line to the object
        describing it.

    Raises:
        TypeError: If ``heavy_ast.ast_tree`` is not an ``ast.Module``.
    """
    if isinstance(heavy_ast.ast_tree, ast.Module):
        found = {}
        _get_all_docstrings(heavy_ast.ast_tree, found, heavy_ast)
        return found
    raise TypeError('Expected tree to be a module.')
def rewrite_docstrings(filename):
    """Re-indent the docstrings of the file at ``filename`` in place.

    Args:
        filename: Path to the Python source file to rewrite.

    Returns:
        The dictionary of docstrings that was used for the rewrite, keyed
        by starting line.
    """
    parsed_source = HeavyAst(filename)
    docstrings_by_start = get_all_docstrings(parsed_source)
    parsed_source.rewrite_file(docstrings_by_start)
    return docstrings_by_start
def example():
    """Run the rewriter on a small sample file and check the outcome.

    A temporary file is filled with a function whose docstring body is
    indented less than the surrounding code. After ``rewrite_docstrings``
    runs, the ``.bak`` copy must equal the input and the file itself must
    equal the expected, re-indented source. The temporary file and its
    backup are left on disk so they can be inspected.

    Raises:
        ValueError: If the backup or the rewritten file has unexpected
            contents.
    """
    # NOTE(review): the indentation inside these two fixtures was
    #               reconstructed, since the copy of this file under review
    #               had runs of whitespace collapsed -- confirm against the
    #               original. The sample's print uses call syntax so that
    #               it parses on both Python 2 and Python 3.
    pep8ified_source = '\n'.join([
        'def hello_func(name):',
        '    """Prints hello with the name.',
        '',
        '  Args:',
        '    name: String, to print.',
        '  """',
        '    print(\'Hello %s, nice to meet you.\' % (name, ))',
        '',
    ])
    desired_source = '\n'.join([
        'def hello_func(name):',
        '    """Prints hello with the name.',
        '',
        '    Args:',
        '      name: String, to print.',
        '    """',
        '    print(\'Hello %s, nice to meet you.\' % (name, ))',
        '',
    ])

    # NamedTemporaryFile owns (and closes) the descriptor, unlike taking
    # only the path from mkstemp().
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as fh:
        filename = fh.name
        print('Making example with temp file: %s' % (filename,))
        fh.write(pep8ified_source)

    all_docstrings = rewrite_docstrings(filename)
    print('=' * 70)
    print('All docstrings found:')
    for start in sorted(all_docstrings):
        print(all_docstrings[start])

    # Check that back-up worked.
    with open(filename + '.bak', 'r') as fh:
        backed_up = fh.read()
    if backed_up != pep8ified_source:
        raise ValueError('Back-up did not work correctly.')
    print('Back-up succeeded.')

    # Check that the indent was correct.
    with open(filename, 'r') as fh:
        rewrite_content = fh.read()
    if rewrite_content != desired_source:
        raise ValueError('Indent did not work correctly.')
    print('Indent succeeded, new file:')
    print('=' * 70)
    # The content ends with a newline, so the rule starts on its own line.
    print(rewrite_content + ('=' * 70))
if __name__ == '__main__':
    # Script entry point: rewrite the file named by the first command-line
    # argument, or run the built-in example when no argument is given.
    # H/T: http://stackoverflow.com/a/9093598/1068170
    if hasattr(__builtin__, '__IPYTHON__'):
        print('In IPYTHON, not running main().')
    elif len(sys.argv) > 1:
        filename = sys.argv[1]
        rewrite_docstrings(filename)
    else:
        example()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment