Skip to content

Commit

Permalink
Merge pull request #25 from Wytamma/update-docstrings
Browse files Browse the repository at this point in the history
Update docstrings
  • Loading branch information
Wytamma authored Mar 15, 2024
2 parents 72e3f5b + 1e356f3 commit 937105f
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 28 deletions.
4 changes: 2 additions & 2 deletions tests/data/multiply_docstring.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def multiply(a, b):
def multiply(a, c):
"""
Multiplies 2 numbers.
Args:
Expand All @@ -10,4 +10,4 @@ def multiply(a, b):
>>> multiply(2, 3)
6
"""
return a * b
return a * c
4 changes: 2 additions & 2 deletions tests/test_cst_docstring_remover.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import libcst as cst
from write_the.cst.docstring_remover import DocstringRemover, remove_docstrings
from write_the.cst.docstring_remover import DocstringRemover, remove_docstrings_from_tree
from write_the.cst.utils import get_docstring


Expand Down Expand Up @@ -47,7 +47,7 @@ def test_leave_ClassDef(tree, nodes):


def test_remove_docstrings(tree, nodes):
updated_tree = remove_docstrings(tree, nodes)
updated_tree = remove_docstrings_from_tree(tree, nodes)
assert get_docstring(updated_tree.body[0]) is None
assert get_docstring(updated_tree.body[1]) == "'''This is another docstring.'''"
assert get_docstring(updated_tree.body[2]) == "'''This is a class docstring.'''"
Expand Down
7 changes: 7 additions & 0 deletions write_the/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ async def docs(
"-f",
help="Generate docstings even if they already exist.",
),
update: bool = typer.Option(
False,
"--update/--no-update",
"-u",
help="Update the existing docstrings.",
),
batch: bool = typer.Option(
False, "--batch/--no-batch", "-b", help="Send each node as a separate request."
),
Expand Down Expand Up @@ -133,6 +139,7 @@ async def docs(
file,
nodes=nodes,
force=force,
update=update,
save=save,
context=context,
background=background,
Expand Down
27 changes: 21 additions & 6 deletions write_the/cli/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from write_the.commands import write_the_docs
from write_the.errors import FileSkippedError
from write_the.utils import create_tree, format_source_code, load_source_code
from rich.syntax import Syntax
from rich.progress import Progress
Expand All @@ -10,6 +11,7 @@
async def async_cli_task(
file: Path,
nodes: List,
update: bool,
force: bool,
save: bool,
context: bool,
Expand Down Expand Up @@ -44,6 +46,7 @@ async def async_cli_task(
"""
task_id = progress.add_task(description=f"{file}", total=None)
failed = False
skipped = False
source_code = load_source_code(file=file)
if pretty:
source_code = format_source_code(source_code)
Expand All @@ -56,6 +59,7 @@ async def async_cli_task(
result = await write_the_docs(
tree,
node_names=nodes,
update=update,
force=force,
save=save,
context=context,
Expand All @@ -70,16 +74,27 @@ async def async_cli_task(
except InvalidRequestError as e:
msg = f" - {e}"
failed = True
except FileSkippedError as e:
msg = f" - {e}"
skipped = True

progress.remove_task(task_id)
progress.refresh()
if print_status or save or failed:
icon = "❌" if failed else "✅"
colour = "red" if failed else "green"
if print_status or save or failed or skipped:
if skipped:
icon = "⏭️"
colour = "yellow"
elif failed:
icon = "❌"
colour = "red"
else:
icon = "✅"
colour = "green"
progress.print(
f"[not underline]{icon} [/not underline]{file}{msg}",
style=f"bold {colour} underline",
f"[not underline]{icon} [/not underline][underline]{file}[/underline]{msg}",
style=f"bold {colour}",
)
if failed:
if failed or skipped:
return None
if save:
with open(file, "w") as f:
Expand Down
20 changes: 14 additions & 6 deletions write_the/commands/docs/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from write_the.cst.node_extractor import extract_nodes_from_tree
from write_the.cst.node_batcher import create_batches
from write_the.commands.docs.utils import extract_block
from write_the.errors import FileSkippedError
from write_the.llm import LLM
from .prompts import write_docstings_for_nodes_prompt
from .prompts import write_docstrings_for_nodes_prompt, update_docstrings_for_nodes_prompt


async def write_the_docs(
tree: cst.Module,
node_names=[],
update=False,
force=False,
save=False,
context=False,
Expand Down Expand Up @@ -55,11 +57,16 @@ async def write_the_docs(
extract_specific_nodes = True
force = True
else:
node_names = get_node_names(tree, force)
node_names = get_node_names(tree, force=force, update=update)
if not node_names:
return tree.code
# batch
llm = LLM(write_docstings_for_nodes_prompt, model_name=model)
raise FileSkippedError("No nodes found, skipping file...")
if update:
remove_docstrings = False
llm = LLM(update_docstrings_for_nodes_prompt, model_name=model)
else:
remove_docstrings = True
llm = LLM(write_docstrings_for_nodes_prompt, model_name=model)

batches = create_batches(
tree=tree,
node_names=node_names,
Expand All @@ -69,6 +76,7 @@ async def write_the_docs(
max_batch_size=max_batch_size,
send_background_context=background,
send_node_context=context,
remove_docstrings=remove_docstrings
)
promises = []
node_names_list = []
Expand All @@ -82,7 +90,7 @@ async def write_the_docs(
docstring_dict = {}
for node_names, result in zip(node_names_list, results):
docstring_dict.update(extract_block(result, node_names))
modified_tree = add_docstrings_to_tree(tree, docstring_dict, force=force)
modified_tree = add_docstrings_to_tree(tree, docstring_dict, force=force or update)
if not save and extract_specific_nodes:
extracted_nodes = extract_nodes_from_tree(modified_tree, node_names)
modified_tree = nodes_to_tree(extracted_nodes)
Expand Down
52 changes: 51 additions & 1 deletion write_the/commands/docs/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,56 @@ def add(a, b):
{code}
Formatted docstrings for {nodes}:
"""
write_docstings_for_nodes_prompt = PromptTemplate(
write_docstrings_for_nodes_prompt = PromptTemplate(
input_variables=["code", "nodes"], template=docs_template
)

update_docs_template = """
Update the Google style docstrings to match the code.
Add, update or remove description, parameter types, exceptions, side effects, notes, examples, etc. if required.
Return only the docstrings, with function/class names as yaml keys.
Use the Class.method format for methods.
Example:
def add(first, second, third=0):
\"\"\"
Sums 2 numbers.
Args:
a (int): The first number to add.
b (int): The second number to add.
Returns:
int: The sum of a and b.
Examples:
>>> add(1, 2)
3
\"\"\"
return first + second + third
Updated docstrings for add:
add:
Sums up to 3 numbers.
Args:
first (int): The first number to add.
second (int): The second number to add.
third (int, optional): The third number to add. Defaults to 0.
Returns:
int: The sum of first, second, and third.
Examples:
>>> add(1, 2)
3
>>> add(1, 2, 3)
6
Code:
{code}
Updated docstrings for {nodes}:
"""

update_docstrings_for_nodes_prompt = PromptTemplate(
input_variables=["code", "nodes"], template=update_docs_template
)
2 changes: 1 addition & 1 deletion write_the/cst/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .docstring_adder import DocstringAdder, has_docstring
from .docstring_remover import DocstringRemover, remove_docstrings
from .docstring_remover import DocstringRemover, remove_docstrings_from_tree
from .function_and_class_collector import FunctionAndClassCollector, get_node_names
from .node_extractor import NodeExtractor, extract_nodes_from_tree
from .node_remover import NodeRemover, remove_nodes_from_tree
Expand Down
2 changes: 1 addition & 1 deletion write_the/cst/docstring_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def leave_ClassDef(
return updated_node


def remove_docstrings(tree, nodes):
def remove_docstrings_from_tree(tree, nodes):
"""
Removes the docstrings from a tree of nodes.
Args:
Expand Down
36 changes: 29 additions & 7 deletions write_the/cst/function_and_class_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,27 @@


class FunctionAndClassCollector(cst.CSTVisitor):
def __init__(self, force):
"""
A CSTVisitor that collects the names of functions and classes from a CST tree.
"""
def __init__(self, force, update=False):
"""
Initializes the FunctionAndClassCollector.
Args:
force (bool): Whether to force the collection of functions and classes even if they have docstrings.
update (bool): Whether to collect only the functions and classes that already have docstrings, so that those docstrings can be updated. Has no effect when `force` is `True`.
"""
self.functions = []
self.classes = []
self.force = force
self.update = update
self.current_class = None

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
"""
Visits a FunctionDef node and adds it to the list of functions if it does not have a docstring or if `force` is `True`.
Visits a FunctionDef node and adds its name to the list of functions if `force` is `True`, if `update` is `True` and the node has a docstring, or if `update` is `False` and the node does not have a docstring.
Args:
node (cst.FunctionDef): The FunctionDef node to visit.
"""
Expand All @@ -25,33 +32,48 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
if self.current_class
else node.name.value
)
if not has_docstring(node) or self.force:
if self.force:
self.functions.append(name)
elif has_docstring(node) and self.update:
self.functions.append(name)
elif not has_docstring(node) and not self.update:
self.functions.append(name)

def visit_ClassDef(self, node: cst.ClassDef) -> None:
"""
Visits a ClassDef node and adds it to the list of classes if it does not have a docstring or if `force` is `True`.
Visits a ClassDef node and adds its name to the list of classes if `force` is `True`, if `update` is `True` and the node has a docstring, or if `update` is `False` and the node does not have a docstring. Also sets the current class name for nested function collection.
Args:
node (cst.ClassDef): The ClassDef node to visit.
"""
self.current_class = node.name.value
if not has_docstring(node) or self.force:
if self.force:
self.classes.append(node.name.value)
elif has_docstring(node) and self.update:
self.classes.append(node.name.value)
elif not has_docstring(node) and not self.update:
self.classes.append(node.name.value)
# self.visit_ClassDef(node) # Call the superclass method to continue the visit

def leave_ClassDef(self, node: cst.ClassDef) -> None:
"""
Resets the current class name when leaving a ClassDef node.
"""
self.current_class = None


def get_node_names(tree, force):
def get_node_names(tree, force, update=False):
"""
Gets the names of functions and classes from a CST tree.
Args:
tree (cst.CSTNode): The CST tree to traverse.
force (bool): Whether to force the collection of functions and classes even if they have docstrings.
update (bool, optional): Whether to collect only the functions and classes that already have docstrings, so that those docstrings can be updated. Has no effect when `force` is `True`. Defaults to False.
Returns:
list[str]: A list of function and class names.
"""
collector = FunctionAndClassCollector(force)
collector = FunctionAndClassCollector(force, update)
tree.visit(collector)
return collector.classes + collector.functions
6 changes: 4 additions & 2 deletions write_the/cst/node_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Optional
import libcst as cst
import tiktoken
from write_the.cst.docstring_remover import remove_docstrings
from write_the.cst.docstring_remover import remove_docstrings_from_tree
from write_the.cst.node_extractor import extract_node_from_tree, extract_nodes_from_tree
from write_the.cst.node_remover import remove_nodes_from_tree
from write_the.cst.utils import get_code_from_node, nodes_to_tree
Expand Down Expand Up @@ -173,6 +173,7 @@ def create_batches(
max_batch_size=None,
send_background_context=True,
send_node_context=True,
remove_docstrings=True,
) -> List[NodeBatch]:
"""
Creates batches of nodes from a tree.
Expand All @@ -191,7 +192,8 @@ def create_batches(
>>> create_batches(tree, node_names, max_tokens, prompt_size, response_size_per_node)
[NodeBatch(...), NodeBatch(...)]
"""
tree = remove_docstrings(tree, node_names) # TODO: fix to use Class.method syntax
if remove_docstrings:
tree = remove_docstrings_from_tree(tree, node_names) # TODO: fix to use Class.method syntax
batches = []
background = None
if send_background_context:
Expand Down
2 changes: 2 additions & 0 deletions write_the/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class FileSkippedError(Exception):
"""Exception raised when a file is intentionally skipped, e.g. because it contains no nodes to document."""

0 comments on commit 937105f

Please sign in to comment.