Skip to content

Commit

Permalink
Merge pull request #7 from Wytamma/batch-nodes
Browse files Browse the repository at this point in the history
Batch nodes
  • Loading branch information
Wytamma authored Apr 30, 2023
2 parents 163b3fe + 3a6ddbb commit 4753e76
Show file tree
Hide file tree
Showing 33 changed files with 550 additions and 145 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ site/
test.py
coverage.xml
.DS_Store
wizard/
3 changes: 2 additions & 1 deletion tests/data/calculate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
class Calculate:
def multiply(a, b):
return a * b

def add(a, b):
return a * b
return a * b
2 changes: 1 addition & 1 deletion tests/data/multiply.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
def multiply(a, b):
return a * b
return a * b
2 changes: 1 addition & 1 deletion tests/data/multiply_docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ def multiply(a, b):
>>> multiply(2, 3)
6
"""
return a * b
return a * b
24 changes: 15 additions & 9 deletions tests/test_cli_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import unittest.mock as mock


@pytest.fixture(scope='function')
@pytest.fixture(scope="function")
def file_path(tmp_path) -> Path:
temp_file = tmp_path / "test_add.py"
temp_file.write_text("def add(a, b):\n return a + b")
Expand Down Expand Up @@ -36,7 +36,10 @@ def test_callback_version():
(False, True, True, True),
],
)
@mock.patch('write_the.llm.LLM.run', return_value="\n\nadd:\n Sums 2 numbers.\n Args:\n a (int): The first number to add.\n b (int): The second number to add.\n Returns:\n int: The sum of `a` and `b`.\n Examples:\n >>> add(1, 2)\n 3\n\n")
@mock.patch(
"write_the.llm.LLM.run",
return_value="\n\nadd:\n Sums 2 numbers.\n Args:\n a (int): The first number to add.\n b (int): The second number to add.\n Returns:\n int: The sum of `a` and `b`.\n Examples:\n >>> add(1, 2)\n 3\n\n",
)
def test_docs_mocked(mocked_run, file_path: Path, nodes, save, context, pretty, force):
runner = CliRunner()
args = ["docs", str(file_path)]
Expand Down Expand Up @@ -75,9 +78,9 @@ def test_mkdocs(tmp_path: Path):
print(result.stdout)
assert result.exit_code == 0
files = [f.name for f in tmp_path.glob("*")]
assert 'mkdocs.yml' in files
assert '.github' in files
assert 'docs' in files
assert "mkdocs.yml" in files
assert ".github" in files
assert "docs" in files


@pytest.mark.parametrize(
Expand All @@ -90,14 +93,17 @@ def test_mkdocs(tmp_path: Path):
(False, True, True),
],
)
@mock.patch('write_the.llm.LLM.run', return_value="""@pytest.mark.parametrize(
@mock.patch(
"write_the.llm.LLM.run",
return_value="""@pytest.mark.parametrize(
"a, b, expected", [(2, 3, 5), (0, 5, 5), (-2, -3, -5), (2.5, 3, 5.5), (2, -3, -1)]
)
def test_add(a, b, expected):
assert add(a, b) == expected""")
assert add(a, b) == expected""",
)
def test_tests_mocked(mocked_run, file_path: Path, save, pretty, force):
runner = CliRunner()
test_dir = file_path.parent / 'docs'
test_dir = file_path.parent / "docs"
args = ["tests", str(file_path), "--out", test_dir]

if save:
Expand All @@ -117,4 +123,4 @@ def test_tests_mocked(mocked_run, file_path: Path, save, pretty, force):
assert "assert add(a, b) == expected" in test_file.read_text()
assert str(file_path) in result.stdout
else:
assert "assert add(a, b) == expected" in result.stdout
assert "assert add(a, b) == expected" in result.stdout
34 changes: 17 additions & 17 deletions tests/test_cst_docstring_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,28 @@ def function_def_node():
return cst.FunctionDef(
name=cst.Name("function_name"),
params=cst.Parameters(),
body=cst.IndentedBlock(
body=[cst.SimpleStatementLine(body=[cst.Pass()])]),
body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[cst.Pass()])]),
)


@pytest.fixture
def class_def_node():
return cst.ClassDef(name=cst.Name("ClassName"), body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[cst.Pass()])]))
return cst.ClassDef(
name=cst.Name("ClassName"),
body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[cst.Pass()])]),
)


@pytest.fixture
def method_def_node():
method_def = cst.FunctionDef(
name=cst.Name("method_name"),
params=cst.Parameters(params=[cst.Param(name=cst.Name("cls"))]),
body=cst.IndentedBlock(
body=[cst.SimpleStatementLine(body=[cst.Pass()])]),
body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[cst.Pass()])]),
)
return cst.ClassDef(
name=cst.Name("ClassName"),
body=cst.IndentedBlock(
body=[cst.SimpleStatementLine(body=[method_def])]),
body=cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[method_def])]),
)


Expand All @@ -66,16 +66,14 @@ def test_leave_function_def_without_docstring(docstrings, force, function_def_no

def test_leave_class_def_with_docstring(docstrings, force, class_def_node):
docstring_adder = DocstringAdder(docstrings, force)
updated_node = docstring_adder.leave_ClassDef(
class_def_node, class_def_node)
updated_node = docstring_adder.leave_ClassDef(class_def_node, class_def_node)
assert has_docstring(updated_node) is False


def test_leave_class_def_without_docstring(docstrings, force, class_def_node):
docstrings.pop("ClassName.method_name")
docstring_adder = DocstringAdder(docstrings, force)
updated_node = docstring_adder.leave_ClassDef(
class_def_node, class_def_node)
updated_node = docstring_adder.leave_ClassDef(class_def_node, class_def_node)
assert not has_docstring(updated_node)


Expand All @@ -84,10 +82,8 @@ def test_leave_method_def_without_docstring(
):
docstrings.pop("ClassName.method_name")
docstring_adder = DocstringAdder(docstrings, force)
updated_node = docstring_adder.leave_ClassDef(
class_def_node, class_def_node)
updated_node = docstring_adder.leave_FunctionDef(
method_def_node, method_def_node)
updated_node = docstring_adder.leave_ClassDef(class_def_node, class_def_node)
updated_node = docstring_adder.leave_FunctionDef(method_def_node, method_def_node)
assert not has_docstring(updated_node)


Expand All @@ -110,10 +106,14 @@ def test_add_docstring_with_force(docstrings, function_def_node):
updated_node = docstring_adder.add_docstring(function_def_node)
assert has_docstring(updated_node)


def test_add_docstring_escape_newline(docstrings, function_def_node):
force = True
docstrings['function_name'] = """\\ntest\ntest\\\\n\\n"""
docstrings["function_name"] = """\\ntest\ntest\\\\n\\n"""
docstring_adder = DocstringAdder(docstrings, force)
updated_node = docstring_adder.add_docstring(function_def_node)
assert has_docstring(updated_node)
assert get_docstring(updated_node).strip('"""').strip() == """\\\\ntest\n test\\\\n\\\\n"""
assert (
get_docstring(updated_node).strip('"""').strip()
== """\\\\ntest\n test\\\\n\\\\n"""
)
3 changes: 2 additions & 1 deletion tests/test_cst_docstring_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from write_the.cst.docstring_remover import DocstringRemover, remove_docstrings
from write_the.cst.utils import get_docstring


@pytest.fixture
def tree():
return cst.parse_module(
Expand Down Expand Up @@ -34,7 +35,7 @@ def nodes():
def test_leave_FunctionDef(tree, nodes):
remover = DocstringRemover(nodes)
updated_tree = tree.visit(remover)
assert get_docstring(updated_tree.body[0]) is None
assert get_docstring(updated_tree.body[0]) is None
assert get_docstring(updated_tree.body[1]) == "'''This is another docstring.'''"


Expand Down
4 changes: 2 additions & 2 deletions tests/test_cst_function_and_class_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def test_visit_ClassDef_with_docstring(tree):


def test_get_node_names(tree, force):
assert get_node_names(tree, force) == ["foo", "Bar"]
assert get_node_names(tree, force) == ["Bar", "foo"]


def test_get_node_names_with_force_true(tree, force):
assert get_node_names(tree, True) == ["foo", "Bar"]
assert get_node_names(tree, True) == ["Bar", "foo"]
1 change: 1 addition & 0 deletions tests/test_cst_node_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import libcst as cst
from write_the.cst.node_extractor import NodeExtractor, extract_nodes_from_tree


@pytest.fixture
def tree():
return cst.parse_module(
Expand Down
45 changes: 28 additions & 17 deletions tests/test_cst_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,42 @@
from write_the.cst.utils import has_docstring, nodes_to_tree
import pytest


@pytest.fixture
def cst_function_def():
return cst.FunctionDef(
name=cst.Name("function_name"),
params=cst.Parameters(),
body=cst.IndentedBlock(
body=[
cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString('"""This is a docstring."""'))]),
cst.SimpleStatementLine(body=[cst.Pass()]),
]
),
)
name=cst.Name("function_name"),
params=cst.Parameters(),
body=cst.IndentedBlock(
body=[
cst.SimpleStatementLine(
body=[
cst.Expr(value=cst.SimpleString('"""This is a docstring."""'))
]
),
cst.SimpleStatementLine(body=[cst.Pass()]),
]
),
)


@pytest.fixture
def cst_class_def():
return cst.ClassDef(
name=cst.Name("ClassName"),
body=cst.IndentedBlock(
body=[
cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString('"""This is a class docstring."""'))]),
cst.SimpleStatementLine(body=[cst.Pass()]),
]
),
)
name=cst.Name("ClassName"),
body=cst.IndentedBlock(
body=[
cst.SimpleStatementLine(
body=[
cst.Expr(
value=cst.SimpleString('"""This is a class docstring."""')
)
]
),
cst.SimpleStatementLine(body=[cst.Pass()]),
]
),
)


def test_has_docstring_function_def(cst_function_def):
Expand Down
2 changes: 1 addition & 1 deletion write_the/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.7"
__version__ = "0.8.0"
2 changes: 1 addition & 1 deletion write_the/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .main import app
from .main import app
Loading

0 comments on commit 4753e76

Please sign in to comment.