Skip to content

Commit

Permalink
Fixed implicit conversion handling in stubgen
Browse files Browse the repository at this point in the history
  • Loading branch information
TimSchneider42 committed Jun 3, 2024
1 parent 750153d commit fac29cc
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions custom_stubgen.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
#!/usr/bin/env python3
import logging
from pathlib import Path
from typing import Dict, Optional
from typing import Dict, Optional, Sequence
from collections import defaultdict

from pybind11_stubgen import Writer, QualifiedName, Printer, arg_parser, stub_parser_from_args, to_output_and_subdir, \
run
from pybind11_stubgen.structs import Function, ResolvedType, Module


class CustomWriter(Writer):
def __init__(self, implicit_conversions: Dict[str, str], stub_ext: str = "pyi"):
def __init__(self, alternative_types: Dict[str, Sequence[str, ...]], stub_ext: str = "pyi"):
super().__init__(stub_ext=stub_ext)
self.implicit_conversions = {
QualifiedName.from_str(k): QualifiedName.from_str(v) for k, v in implicit_conversions.items()
self.alternative_types = {
QualifiedName.from_str(k): tuple(QualifiedName.from_str(e) for e in v) for k, v in alternative_types.items()
}

def _patch_function(self, function: Function):
for argument in function.args:
if argument.annotation is not None and argument.annotation.name in self.implicit_conversions:
converted_type = ResolvedType(self.implicit_conversions[argument.annotation.name])
if argument.annotation is not None and argument.annotation.name in self.alternative_types:
converted_types = [ResolvedType(e) for e in self.alternative_types[argument.annotation.name]]
argument.annotation = ResolvedType(
QualifiedName.from_str("typing.Union"), [argument.annotation, converted_type])
QualifiedName.from_str("typing.Union"), [argument.annotation] + converted_types)

def write_module(self, module: Module, printer: Printer, to: Path, sub_dir: Optional[Path] = None):
for cls in module.classes:
Expand All @@ -32,11 +33,20 @@ def write_module(self, module: Module, printer: Printer, to: Path, sub_dir: Opti
super().write_module(module, printer, to, sub_dir=sub_dir)


IMPLICIT_CONVERSIONS = {
"bool": "Condition",
"float": "RelativeDynamicsFactor",
"Affine": "RobotPose",
}
IMPLICIT_CONVERSIONS = [
("bool", "Condition"),
("float", "RelativeDynamicsFactor"),
("Affine", "RobotPose"),
("Twist", "RobotVelocity"),
("RobotPose", "CartesianState"),
("Affine", "CartesianState"),
("list[float]", "JointState"),
("np.ndarray", "JointState")
]

alternatives = defaultdict(list)
for from_type, to_type in IMPLICIT_CONVERSIONS:
alternatives[to_type].append(from_type)

if __name__ == "__main__":
logging.basicConfig(
Expand Down

0 comments on commit fac29cc

Please sign in to comment.