Skip to content

Commit 937105f

Browse files
authored
Merge pull request #25 from Wytamma/update-docstrings
Update docstrings
2 parents 72e3f5b + 1e356f3 commit 937105f

File tree

11 files changed

+134
-28
lines changed

11 files changed

+134
-28
lines changed

tests/data/multiply_docstring.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
def multiply(a, b):
1+
def multiply(a, c):
22
"""
33
Multiplies 2 numbers.
44
Args:
@@ -10,4 +10,4 @@ def multiply(a, b):
1010
>>> multiply(2, 3)
1111
6
1212
"""
13-
return a * b
13+
return a * c

tests/test_cst_docstring_remover.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
import libcst as cst
3-
from write_the.cst.docstring_remover import DocstringRemover, remove_docstrings
3+
from write_the.cst.docstring_remover import DocstringRemover, remove_docstrings_from_tree
44
from write_the.cst.utils import get_docstring
55

66

@@ -47,7 +47,7 @@ def test_leave_ClassDef(tree, nodes):
4747

4848

4949
def test_remove_docstrings(tree, nodes):
50-
updated_tree = remove_docstrings(tree, nodes)
50+
updated_tree = remove_docstrings_from_tree(tree, nodes)
5151
assert get_docstring(updated_tree.body[0]) is None
5252
assert get_docstring(updated_tree.body[1]) == "'''This is another docstring.'''"
5353
assert get_docstring(updated_tree.body[2]) == "'''This is a class docstring.'''"

write_the/cli/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ async def docs(
9898
"-f",
9999
help="Generate docstings even if they already exist.",
100100
),
101+
update: bool = typer.Option(
102+
False,
103+
"--update/--no-update",
104+
"-u",
105+
help="Update the existing docstrings.",
106+
),
101107
batch: bool = typer.Option(
102108
False, "--batch/--no-batch", "-b", help="Send each node as a separate request."
103109
),
@@ -133,6 +139,7 @@ async def docs(
133139
file,
134140
nodes=nodes,
135141
force=force,
142+
update=update,
136143
save=save,
137144
context=context,
138145
background=background,

write_the/cli/tasks.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from write_the.commands import write_the_docs
2+
from write_the.errors import FileSkippedError
23
from write_the.utils import create_tree, format_source_code, load_source_code
34
from rich.syntax import Syntax
45
from rich.progress import Progress
@@ -10,6 +11,7 @@
1011
async def async_cli_task(
1112
file: Path,
1213
nodes: List,
14+
update: bool,
1315
force: bool,
1416
save: bool,
1517
context: bool,
@@ -44,6 +46,7 @@ async def async_cli_task(
4446
"""
4547
task_id = progress.add_task(description=f"{file}", total=None)
4648
failed = False
49+
skipped = False
4750
source_code = load_source_code(file=file)
4851
if pretty:
4952
source_code = format_source_code(source_code)
@@ -56,6 +59,7 @@ async def async_cli_task(
5659
result = await write_the_docs(
5760
tree,
5861
node_names=nodes,
62+
update=update,
5963
force=force,
6064
save=save,
6165
context=context,
@@ -70,16 +74,27 @@ async def async_cli_task(
7074
except InvalidRequestError as e:
7175
msg = f" - {e}"
7276
failed = True
77+
except FileSkippedError as e:
78+
msg = f" - {e}"
79+
skipped = True
80+
7381
progress.remove_task(task_id)
7482
progress.refresh()
75-
if print_status or save or failed:
76-
icon = "❌" if failed else "✅"
77-
colour = "red" if failed else "green"
83+
if print_status or save or failed or skipped:
84+
if skipped:
85+
icon = "⏭️"
86+
colour = "yellow"
87+
elif failed:
88+
icon = "❌"
89+
colour = "red"
90+
else:
91+
icon = "✅"
92+
colour = "green"
7893
progress.print(
79-
f"[not underline]{icon} [/not underline]{file}{msg}",
80-
style=f"bold {colour} underline",
94+
f"[not underline]{icon} [/not underline][underline]{file}[/underline]{msg}",
95+
style=f"bold {colour}",
8196
)
82-
if failed:
97+
if failed or skipped:
8398
return None
8499
if save:
85100
with open(file, "w") as f:

write_the/commands/docs/docs.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from write_the.cst.node_extractor import extract_nodes_from_tree
99
from write_the.cst.node_batcher import create_batches
1010
from write_the.commands.docs.utils import extract_block
11+
from write_the.errors import FileSkippedError
1112
from write_the.llm import LLM
12-
from .prompts import write_docstings_for_nodes_prompt
13+
from .prompts import write_docstrings_for_nodes_prompt, update_docstrings_for_nodes_prompt
1314

1415

1516
async def write_the_docs(
1617
tree: cst.Module,
1718
node_names=[],
19+
update=False,
1820
force=False,
1921
save=False,
2022
context=False,
@@ -55,11 +57,16 @@ async def write_the_docs(
5557
extract_specific_nodes = True
5658
force = True
5759
else:
58-
node_names = get_node_names(tree, force)
60+
node_names = get_node_names(tree, force=force, update=update)
5961
if not node_names:
60-
return tree.code
61-
# batch
62-
llm = LLM(write_docstings_for_nodes_prompt, model_name=model)
62+
raise FileSkippedError("No nodes found, skipping file...")
63+
if update:
64+
remove_docstrings = False
65+
llm = LLM(update_docstrings_for_nodes_prompt, model_name=model)
66+
else:
67+
remove_docstrings = True
68+
llm = LLM(write_docstrings_for_nodes_prompt, model_name=model)
69+
6370
batches = create_batches(
6471
tree=tree,
6572
node_names=node_names,
@@ -69,6 +76,7 @@ async def write_the_docs(
6976
max_batch_size=max_batch_size,
7077
send_background_context=background,
7178
send_node_context=context,
79+
remove_docstrings=remove_docstrings
7280
)
7381
promises = []
7482
node_names_list = []
@@ -82,7 +90,7 @@ async def write_the_docs(
8290
docstring_dict = {}
8391
for node_names, result in zip(node_names_list, results):
8492
docstring_dict.update(extract_block(result, node_names))
85-
modified_tree = add_docstrings_to_tree(tree, docstring_dict, force=force)
93+
modified_tree = add_docstrings_to_tree(tree, docstring_dict, force=force or update)
8694
if not save and extract_specific_nodes:
8795
extracted_nodes = extract_nodes_from_tree(modified_tree, node_names)
8896
modified_tree = nodes_to_tree(extracted_nodes)

write_the/commands/docs/prompts.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,56 @@ def add(a, b):
2929
{code}
3030
Formatted docstrings for {nodes}:
3131
"""
32-
write_docstings_for_nodes_prompt = PromptTemplate(
32+
write_docstrings_for_nodes_prompt = PromptTemplate(
3333
input_variables=["code", "nodes"], template=docs_template
3434
)
35+
36+
update_docs_template = """
37+
Update the Google style docstrings to match the code.
38+
Add, update or remove description, parameter types, exceptions, side effects, notes, examples, etc. if required.
39+
Return only the docstrings, with function/class names as yaml keys.
40+
Use the Class.method format for methods.
41+
42+
Example:
43+
def add(first, second, third=0):
44+
\"\"\"
45+
Sums 2 numbers.
46+
47+
Args:
48+
a (int): The first number to add.
49+
b (int): The second number to add.
50+
51+
Returns:
52+
int: The sum of a and b.
53+
54+
Examples:
55+
>>> add(1, 2)
56+
3
57+
\"\"\"
58+
return first + second + third
59+
Updated docstrings for add:
60+
add:
61+
Sums up to 3 numbers.
62+
63+
Args:
64+
first (int): The first number to add.
65+
second (int): The second number to add.
66+
third (int, optional): The third number to add. Defaults to 0.
67+
68+
Returns:
69+
int: The sum of first, second, and third.
70+
71+
Examples:
72+
>>> add(1, 2)
73+
3
74+
>>> add(1, 2, 3)
75+
6
76+
77+
Code:
78+
{code}
79+
Updated docstrings for {nodes}:
80+
"""
81+
82+
update_docstrings_for_nodes_prompt = PromptTemplate(
83+
input_variables=["code", "nodes"], template=update_docs_template
84+
)

write_the/cst/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .docstring_adder import DocstringAdder, has_docstring
2-
from .docstring_remover import DocstringRemover, remove_docstrings
2+
from .docstring_remover import DocstringRemover, remove_docstrings_from_tree
33
from .function_and_class_collector import FunctionAndClassCollector, get_node_names
44
from .node_extractor import NodeExtractor, extract_nodes_from_tree
55
from .node_remover import NodeRemover, remove_nodes_from_tree

write_the/cst/docstring_remover.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def leave_ClassDef(
5252
return updated_node
5353

5454

55-
def remove_docstrings(tree, nodes):
55+
def remove_docstrings_from_tree(tree, nodes):
5656
"""
5757
Removes the docstrings from a tree of nodes.
5858
Args:

write_the/cst/function_and_class_collector.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,27 @@
33

44

55
class FunctionAndClassCollector(cst.CSTVisitor):
6-
def __init__(self, force):
6+
"""
7+
A CSTVisitor that collects the names of functions and classes from a CST tree.
8+
"""
9+
def __init__(self, force, update=False):
710
"""
811
Initializes the FunctionAndClassCollector.
12+
913
Args:
1014
force (bool): Whether to force the collection of functions and classes even if they have docstrings.
15+
update (bool): Whether to collect only the functions and classes that already have docstrings, so those docstrings can be updated. Has no effect when `force` is `True`.
1116
"""
1217
self.functions = []
1318
self.classes = []
1419
self.force = force
20+
self.update = update
1521
self.current_class = None
1622

1723
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
1824
"""
19-
Visits a FunctionDef node and adds it to the list of functions if it does not have a docstring or if `force` is `True`.
25+
Visits a FunctionDef node and adds its name to the list of functions if `force` is `True`, if it has a docstring and `update` is `True`, or if it has no docstring and `update` is `False`.
26+
2027
Args:
2128
node (cst.FunctionDef): The FunctionDef node to visit.
2229
"""
@@ -25,33 +32,48 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
2532
if self.current_class
2633
else node.name.value
2734
)
28-
if not has_docstring(node) or self.force:
35+
if self.force:
36+
self.functions.append(name)
37+
elif has_docstring(node) and self.update:
38+
self.functions.append(name)
39+
elif not has_docstring(node) and not self.update:
2940
self.functions.append(name)
3041

3142
def visit_ClassDef(self, node: cst.ClassDef) -> None:
3243
"""
33-
Visits a ClassDef node and adds it to the list of classes if it does not have a docstring or if `force` is `True`.
44+
Visits a ClassDef node and adds its name to the list of classes if `force` is `True`, if it has a docstring and `update` is `True`, or if it has no docstring and `update` is `False`. Also sets the current class name for nested function collection.
45+
3446
Args:
3547
node (cst.ClassDef): The ClassDef node to visit.
3648
"""
3749
self.current_class = node.name.value
38-
if not has_docstring(node) or self.force:
50+
if self.force:
51+
self.classes.append(node.name.value)
52+
elif has_docstring(node) and self.update:
53+
self.classes.append(node.name.value)
54+
elif not has_docstring(node) and not self.update:
3955
self.classes.append(node.name.value)
4056
# self.visit_ClassDef(node) # Call the superclass method to continue the visit
4157

4258
def leave_ClassDef(self, node: cst.ClassDef) -> None:
59+
"""
60+
Resets the current class name when leaving a ClassDef node.
61+
"""
4362
self.current_class = None
4463

4564

46-
def get_node_names(tree, force):
65+
def get_node_names(tree, force, update=False):
4766
"""
4867
Gets the names of functions and classes from a CST tree.
68+
4969
Args:
5070
tree (cst.CSTNode): The CST tree to traverse.
5171
force (bool): Whether to force the collection of functions and classes even if they have docstrings.
72+
update (bool, optional): Whether to collect only the functions and classes that already have docstrings, so those docstrings can be updated. Has no effect when `force` is `True`. Defaults to False.
73+
5274
Returns:
5375
list[str]: A list of function and class names.
5476
"""
55-
collector = FunctionAndClassCollector(force)
77+
collector = FunctionAndClassCollector(force, update)
5678
tree.visit(collector)
5779
return collector.classes + collector.functions

write_the/cst/node_batcher.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import List, Optional
33
import libcst as cst
44
import tiktoken
5-
from write_the.cst.docstring_remover import remove_docstrings
5+
from write_the.cst.docstring_remover import remove_docstrings_from_tree
66
from write_the.cst.node_extractor import extract_node_from_tree, extract_nodes_from_tree
77
from write_the.cst.node_remover import remove_nodes_from_tree
88
from write_the.cst.utils import get_code_from_node, nodes_to_tree
@@ -173,6 +173,7 @@ def create_batches(
173173
max_batch_size=None,
174174
send_background_context=True,
175175
send_node_context=True,
176+
remove_docstrings=True,
176177
) -> List[NodeBatch]:
177178
"""
178179
Creates batches of nodes from a tree.
@@ -191,7 +192,8 @@ def create_batches(
191192
>>> create_batches(tree, node_names, max_tokens, prompt_size, response_size_per_node)
192193
[NodeBatch(...), NodeBatch(...)]
193194
"""
194-
tree = remove_docstrings(tree, node_names) # TODO: fix to use Class.method syntax
195+
if remove_docstrings:
196+
tree = remove_docstrings_from_tree(tree, node_names) # TODO: fix to use Class.method syntax
195197
batches = []
196198
background = None
197199
if send_background_context:

0 commit comments

Comments
 (0)