Skip to content

Commit

Permalink
✨ better indentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Wytamma committed Nov 10, 2023
1 parent a1ec796 commit ef92541
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
40 changes: 38 additions & 2 deletions tests/test_cst_docstring_adder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import libcst as cst
from write_the.cst.docstring_adder import DocstringAdder
from write_the.cst.docstring_adder import DocstringAdder, add_docstrings_to_tree
from write_the.cst.utils import has_docstring, get_docstring


Expand All @@ -22,7 +22,7 @@ def function_def_node():
return cst.FunctionDef(
name=cst.Name("function_name"),
params=cst.Parameters(),
body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[cst.Pass()])]),
body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[cst.Pass()])], indent=" "),
)


Expand Down Expand Up @@ -117,3 +117,39 @@ def test_add_docstring_escape_newline(docstrings, function_def_node):
get_docstring(updated_node).strip('"""').strip()
== """\\\\ntest\n test\\\\n\\\\n"""
)

def tree():
return cst.parse_module(
"""
def function_name():
pass
class ClassName:
def method_name():
pass
"""
)

def test_add_docstring_indentation():
docstrings = {
"function_name": """
This is a docstring for a function.
Args:
a (int): The first number to add.
b (int): The second number to add.
Returns:
int: The sum of `a` and `b`.
""",
"ClassName.method_name": """
This is a docstring for a method.
Args:
a (int): The first number to add.
b (int): The second number to add.
Returns:
int: The sum of `a` and `b`.
""",
}
modified_tree = add_docstrings_to_tree(tree(), docstrings, force=True)
code = modified_tree.code
assert " This is a docstring for a function." in code
assert " This is a docstring for a method." in code
14 changes: 9 additions & 5 deletions write_the/cst/docstring_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@


class DocstringAdder(cst.CSTTransformer):
def __init__(self, docstrings, force):
def __init__(self, docstrings, force, indent=" "):
self.docstrings = docstrings
self.force = force
self.indent = indent
self.current_class = None

def leave_FunctionDef(
Expand Down Expand Up @@ -61,13 +62,16 @@ def add_docstring(self, node):
node = remove_docstring(node)
escaped_docstring = re.sub(r"(?<!\\)\\n", "\\\\\\\\n", docstring)
dedented_docstring = textwrap.dedent(escaped_docstring)
indented_docstring = textwrap.indent(dedented_docstring, " ")
new_docstring = cst.parse_statement(f'"""{indented_docstring} """')
body = node.body.with_changes(body=[new_docstring] + list(node.body.body))
indent = self.indent
if self.current_class:
indent = indent * 2
indented_docstring = textwrap.indent(dedented_docstring, indent)
new_docstring = cst.parse_statement(f'"""{indented_docstring}{indent}"""')
body = node.body.with_changes(body=(new_docstring, *node.body.body))
return node.with_changes(body=body)

return node


def add_docstrings_to_tree(tree, docstring_dict, force=False):
return tree.visit(DocstringAdder(docstring_dict, force))
return tree.visit(DocstringAdder(docstring_dict, force=force, indent=tree.config_for_parsing.default_indent))

0 comments on commit ef92541

Please sign in to comment.