Skip to content

Keep inline annotations over docstring generated ones #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/example_pkg-stubs/_basic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def func_contains(
def func_literals(
a1: Literal[1, 3, "foo"], a2: Literal["uno", 2, "drei", "four"] = ...
) -> None: ...
def override_docstring_param(
d1: dict[str, float], d2: dict[Literal["a", "b", "c"], int]
) -> None: ...
def override_docstring_return() -> list[Literal[-1, 0, 1] | float]: ...
def func_use_from_elsewhere(
a1: CustomException,
a2: ExampleClass,
Expand All @@ -37,6 +41,9 @@ def func_use_from_elsewhere(
) -> tuple[CustomException, ExampleClass.NestedClass]: ...

class ExampleClass:

b1: int

class NestedClass:
def method_in_nested_class(self, a1: complex) -> None: ...

Expand Down
27 changes: 26 additions & 1 deletion examples/example_pkg/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# Existing imports are preserved
import logging
from typing import Literal

# Assign-statements are preserved
logger = logging.getLogger(__name__) # Inline comments are stripped
Expand Down Expand Up @@ -51,6 +52,25 @@ def func_literals(a1, a2="uno"):
"""


def override_docstring_param(d1, d2: dict[Literal["a", "b", "c"], int]):
"""Check type hint is kept and overrides docstring.

Parameters
----------
d1 : dict of {str : float}
d2 : dict of {str : int}
"""


def override_docstring_return() -> list[Literal[-1, 0, 1] | float]:
"""Check type hint is kept and overrides docstring.

Returns
-------
{"-inf", 0, 1, "inf"}
"""


def func_use_from_elsewhere(a1, a2, a3, a4):
"""Check if types with full import names are matched.

Expand All @@ -75,10 +95,15 @@ class ExampleClass:
----------
a1 : str
a2 : float, default 0

Attributes
----------
b1 : Sized
"""

class NestedClass:
b1: int

class NestedClass:
def method_in_nested_class(self, a1):
"""

Expand Down
75 changes: 54 additions & 21 deletions src/docstub/_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,25 +571,31 @@ def leave_FunctionDef(self, original_node, updated_node):
assert ds_annotations.returns.value
annotation_value = ds_annotations.returns.value

if original_node.returns is not None:
if original_node.returns is None:
annotation = cst.Annotation(cst.parse_expression(annotation_value))
node_changes["returns"] = annotation
# TODO: check imports
self._required_imports |= ds_annotations.returns.imports

else:
# Notify about ignored docstring annotation
# TODO: either remove message or print only in verbose mode
position = self.get_metadata(
cst.metadata.PositionProvider, original_node
).start
reporter = self.reporter.copy_with(
path=self.current_source, line=position.line
)
replaced = _inline_node_as_code(original_node.returns.annotation)
to_keep = _inline_node_as_code(original_node.returns.annotation)
details = (
f"{replaced}\n{reporter.underline(replaced)} -> {annotation_value}"
f"{reporter.underline(to_keep)} "
f"ignoring docstring: {annotation_value}"
)
reporter.message(
short="Replacing existing inline return annotation",
short="Keeping existing inline return annotation",
details=details,
)

annotation = cst.Annotation(cst.parse_expression(annotation_value))
node_changes["returns"] = annotation
self._required_imports |= ds_annotations.returns.imports
elif original_node.returns is None:
annotation = cst.Annotation(cst.parse_expression("None"))
node_changes["returns"] = annotation
Expand Down Expand Up @@ -633,10 +639,35 @@ def leave_Param(self, original_node, updated_node):
if pytype:
if defaults_to_none:
pytype = pytype.as_optional()
annotation = cst.Annotation(cst.parse_expression(pytype.value))
node_changes["annotation"] = annotation
if pytype.imports:
self._required_imports |= pytype.imports
annotation_value = pytype.value

if original_node.annotation is None:
annotation = cst.Annotation(cst.parse_expression(annotation_value))
node_changes["annotation"] = annotation
# TODO: check imports
if pytype.imports:
self._required_imports |= pytype.imports

else:
# Notify about ignored docstring annotation
# TODO: either remove message or print only in verbose mode
position = self.get_metadata(
cst.metadata.PositionProvider, original_node
).start
reporter = self.reporter.copy_with(
path=self.current_source, line=position.line
)
to_keep = cst.Module([]).code_for_node(
original_node.annotation.annotation
)
details = (
f"{reporter.underline(to_keep)} "
f"ignoring docstring: {annotation_value}"
)
reporter.message(
short="Keeping existing inline parameter annotation",
details=details,
)

# Potentially use "Incomplete" except for first param in (class)methods
elif not is_self_or_cls and updated_node.annotation is None:
Expand Down Expand Up @@ -764,31 +795,33 @@ def leave_AnnAssign(self, original_node, updated_node):
if pytypes and name in pytypes.attributes:
pytype = pytypes.attributes[name]
expr = cst.parse_expression(pytype.value)
self._required_imports |= pytype.imports

if updated_node.annotation is not None:
# Turn original annotation into str and print with context
if updated_node.annotation is None:
self._required_imports |= pytype.imports
updated_node = updated_node.with_deep_changes(
updated_node.annotation, annotation=expr
)

else:
# Notify about ignored docstring annotation
# TODO: either remove message or print only in verbose mode
position = self.get_metadata(
cst.metadata.PositionProvider, original_node
).start
reporter = self.reporter.copy_with(
path=self.current_source, line=position.line
)
replaced = cst.Module([]).code_for_node(
to_keep = cst.Module([]).code_for_node(
updated_node.annotation.annotation
)
details = (
f"{replaced}\n{reporter.underline(replaced)} -> {pytype.value}"
f"{reporter.underline(to_keep)} ignoring docstring: {pytype.value}"
)
reporter.message(
short="Replacing existing inline annotation",
short="Keeping existing inline annotation for assignment",
details=details,
)

updated_node = updated_node.with_deep_changes(
updated_node.annotation, annotation=expr
)

return updated_node

def visit_Module(self, node):
Expand Down
146 changes: 133 additions & 13 deletions tests/test_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,14 @@ def test_attributes_no_doctype(self, assign, expected, scope):
@pytest.mark.parametrize(
("assign", "doctype", "expected"),
[
# ("plain = 3", "plain : int", "plain: int"),
# ("plain = None", "plain : int", "plain: int"),
# ("x, y = (1, 2)", "x : int", "x: int; y: Incomplete"),
# Replace pre-existing annotations
("annotated: float = 1.0", "annotated : int", "annotated: int"),
("plain = 3", "plain : int", "plain: int"),
("plain = None", "plain : int", "plain: int"),
("x, y = (1, 2)", "x : int", "x: int; y: Incomplete"),
# Keep pre-existing annotations
("annotated: float = 1.0", "annotated : int", "annotated: float"),
# Type aliases are untouched
# ("alias: TypeAlias = int", "alias : str", "alias: TypeAlias = int"),
# ("type alias = int", "alias : str", "type alias = int"),
("alias: TypeAlias = int", "alias : str", "alias: TypeAlias = int"),
("type alias = int", "alias : str", "type alias = int"),
],
)
@pytest.mark.parametrize("scope", ["module", "class", "nested class"])
Expand Down Expand Up @@ -283,7 +283,7 @@ class Foo:
a: int
b: float

c: tuple
c: list
d: ClassVar[bool]

def __init__(self, a) -> None: ...
Expand All @@ -298,7 +298,127 @@ def test_undocumented_objects(self):
# https://typing.readthedocs.io/en/latest/guides/writing_stubs.html#undocumented-objects
pass

def test_existing_typed_return(self):
def test_keep_assign_param(self):
source = dedent(
"""
a: str
"""
)
expected = dedent(
"""
a: str
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

def test_keep_inline_assign_with_doctype(self, capsys):
source = dedent(
'''
"""
Attributes
----------
a : Sized
"""
a: str
'''
)
expected = dedent(
"""
a: str
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

captured = capsys.readouterr()
assert "Keeping existing inline annotation for assignment" in captured.out

def test_keep_class_assign_param(self):
source = dedent(
"""
class Foo:
a: str
"""
)
expected = dedent(
"""
class Foo:
a: str
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

def test_keep_inline_class_assign_with_doctype(self, capsys):
source = dedent(
'''
class Foo:
"""
Attributes
----------
a : Sized
"""
a: str
'''
)
expected = dedent(
"""
class Foo:
a: str
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

captured = capsys.readouterr()
assert "Keeping existing inline annotation for assignment" in captured.out

def test_keep_inline_param(self):
source = dedent(
"""
def foo(a: str) -> None:
pass
"""
)
expected = dedent(
"""
def foo(a: str) -> None: ...
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

def test_keep_inline_param_with_doctype(self, capsys):
source = dedent(
'''
def foo(a: int) -> None:
"""
Parameters
----------
a : Sized
"""
pass
'''
)
expected = dedent(
"""
def foo(a: int) -> None: ...
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result

captured = capsys.readouterr()
assert "Keeping existing inline parameter annotation" in captured.out

def test_keep_inline_return(self):
source = dedent(
"""
def foo() -> str:
Expand All @@ -314,14 +434,14 @@ def foo() -> str: ...
result = transformer.python_to_stub(source)
assert expected == result

def test_overwriting_typed_return(self, capsys):
def test_keep_inline_return_with_doctype(self, capsys):
source = dedent(
'''
def foo() -> dict[str, int]:
def foo() -> int:
"""
Returns
-------
out : int
out : Sized
"""
pass
'''
Expand All @@ -336,7 +456,7 @@ def foo() -> int: ...
assert expected == result

captured = capsys.readouterr()
assert "Replacing existing inline return annotation" in captured.out
assert "Keeping existing inline return annotation" in captured.out

def test_preserved_type_comment(self):
source = dedent(
Expand Down