Skip to content
This repository was archived by the owner on Aug 28, 2019. It is now read-only.

Commit cf98dc6

Browse files
committed
[commands] Refactor typing evaluation to not use get_type_hints
get_type_hints had a few issues: 1. It would convert = None default parameters to Optional 2. It would not allow values as type annotations 3. It would not implicitly convert some string literals as ForwardRef In Python 3.9 `list['Foo']` does not convert into `list[ForwardRef('Foo')]` even though `typing.List` does this behaviour. In order to streamline it, evaluation had to be rewritten manually to support our usecases. This patch also flattens nested typing.Literal which was not done until Python 3.9.2.
1 parent 27886e5 commit cf98dc6

File tree

1 file changed

+103
-58
lines changed

1 file changed

+103
-58
lines changed

discord/ext/commands/core.py

Lines changed: 103 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,20 @@
2222
DEALINGS IN THE SOFTWARE.
2323
"""
2424

25+
from typing import (
26+
Any,
27+
Dict,
28+
ForwardRef,
29+
Iterable,
30+
Literal,
31+
Tuple,
32+
Union,
33+
get_args as get_typing_args,
34+
get_origin as get_typing_origin,
35+
)
2536
import asyncio
2637
import functools
2738
import inspect
28-
import typing
2939
import datetime
3040
import sys
3141

@@ -64,6 +74,83 @@
6474
'bot_has_guild_permissions'
6575
)
6676

77+
PY_310 = sys.version_info >= (3, 10)
78+
79+
def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
80+
params = []
81+
literal_cls = type(Literal[0])
82+
for p in parameters:
83+
if isinstance(p, literal_cls):
84+
params.extend(p.__args__)
85+
else:
86+
params.append(p)
87+
return tuple(params)
88+
89+
def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
90+
if isinstance(tp, ForwardRef):
91+
tp = tp.__forward_arg__
92+
# ForwardRefs always evaluate their internals
93+
implicit_str = True
94+
95+
if implicit_str and isinstance(tp, str):
96+
if tp in cache:
97+
return cache[tp]
98+
evaluated = eval(tp, globals)
99+
cache[tp] = evaluated
100+
return _evaluate_annotation(evaluated, globals, cache)
101+
102+
if hasattr(tp, '__args__'):
103+
implicit_str = True
104+
args = tp.__args__
105+
if tp.__origin__ is Literal:
106+
if not PY_38:
107+
args = flatten_literal_params(tp.__args__)
108+
implicit_str = False
109+
110+
evaluated_args = tuple(
111+
_evaluate_annotation(arg, globals, cache, implicit_str=implicit_str) for arg in args
112+
)
113+
114+
if evaluated_args == args:
115+
return tp
116+
117+
try:
118+
return tp.copy_with(evaluated_args)
119+
except AttributeError:
120+
return tp.__origin__[evaluated_args]
121+
122+
return tp
123+
124+
def resolve_annotation(annotation: Any, globalns: Dict[str, Any], cache: Dict[str, Any] = {}) -> Any:
125+
if annotation is None:
126+
return type(None)
127+
if isinstance(annotation, str):
128+
annotation = ForwardRef(annotation)
129+
return _evaluate_annotation(annotation, globalns, cache)
130+
131+
def get_signature_parameters(function) -> Dict[str, inspect.Parameter]:
132+
globalns = function.__globals__
133+
signature = inspect.signature(function)
134+
params = {}
135+
cache: Dict[str, Any] = {}
136+
for name, parameter in signature.parameters.items():
137+
annotation = parameter.annotation
138+
if annotation is parameter.empty:
139+
params[name] = parameter
140+
continue
141+
if annotation is None:
142+
params[name] = parameter.replace(annotation=type(None))
143+
continue
144+
145+
annotation = _evaluate_annotation(annotation, globalns, cache)
146+
if annotation is converters.Greedy:
147+
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
148+
149+
params[name] = parameter.replace(annotation=annotation)
150+
151+
return params
152+
153+
67154
def wrap_callback(coro):
68155
@functools.wraps(coro)
69156
async def wrapped(*args, **kwargs):
@@ -300,40 +387,7 @@ def callback(self):
300387
def callback(self, function):
301388
self._callback = function
302389
self.module = function.__module__
303-
304-
signature = inspect.signature(function)
305-
self.params = signature.parameters.copy()
306-
307-
# see: https://bugs.python.org/issue41341
308-
resolve = self._recursive_resolve if sys.version_info < (3, 9) else self._return_resolved
309-
310-
try:
311-
type_hints = {k: resolve(v) for k, v in typing.get_type_hints(function).items()}
312-
except NameError as e:
313-
raise NameError(f'unresolved forward reference: {e.args[0]}') from None
314-
315-
for key, value in self.params.items():
316-
# coalesce the forward references
317-
if key in type_hints:
318-
self.params[key] = value = value.replace(annotation=type_hints[key])
319-
320-
# fail early for when someone passes an unparameterized Greedy type
321-
if value.annotation is converters.Greedy:
322-
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
323-
324-
def _return_resolved(self, type, **kwargs):
325-
return type
326-
327-
def _recursive_resolve(self, type, *, globals=None):
328-
if not isinstance(type, typing.ForwardRef):
329-
return type
330-
331-
resolved = eval(type.__forward_arg__, globals)
332-
args = typing.get_args(resolved)
333-
for index, arg in enumerate(args):
334-
inner_resolve_result = self._recursive_resolve(arg, globals=globals)
335-
resolved[index] = inner_resolve_result
336-
return resolved
390+
self.params = get_signature_parameters(function)
337391

338392
def add_check(self, func):
339393
"""Adds a check to the command.
@@ -493,12 +547,12 @@ async def _actual_conversion(self, ctx, converter, argument, param):
493547
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
494548

495549
async def do_conversion(self, ctx, converter, argument, param):
496-
origin = typing.get_origin(converter)
550+
origin = get_typing_origin(converter)
497551

498-
if origin is typing.Union:
552+
if origin is Union:
499553
errors = []
500554
_NoneType = type(None)
501-
for conv in typing.get_args(converter):
555+
for conv in get_typing_args(converter):
502556
# if we got to this part in the code, then the previous conversions have failed
503557
# so we should just undo the view, return the default, and allow parsing to continue
504558
# with the other parameters
@@ -514,13 +568,12 @@ async def do_conversion(self, ctx, converter, argument, param):
514568
return value
515569

516570
# if we're here, then we failed all the converters
517-
raise BadUnionArgument(param, typing.get_args(converter), errors)
571+
raise BadUnionArgument(param, get_typing_args(converter), errors)
518572

519-
if origin is typing.Literal:
573+
if origin is Literal:
520574
errors = []
521575
conversions = {}
522-
literal_args = tuple(self._flattened_typing_literal_args(converter))
523-
for literal in literal_args:
576+
for literal in converter.__args__:
524577
literal_type = type(literal)
525578
try:
526579
value = conversions[literal_type]
@@ -538,7 +591,7 @@ async def do_conversion(self, ctx, converter, argument, param):
538591
return value
539592

540593
# if we're here, then we failed to match all the literals
541-
raise BadLiteralArgument(param, literal_args, errors)
594+
raise BadLiteralArgument(param, converter.__args__, errors)
542595

543596
return await self._actual_conversion(ctx, converter, argument, param)
544597

@@ -1021,14 +1074,7 @@ def short_doc(self):
10211074
return ''
10221075

10231076
def _is_typing_optional(self, annotation):
1024-
return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None)
1025-
1026-
def _flattened_typing_literal_args(self, annotation):
1027-
for literal in typing.get_args(annotation):
1028-
if typing.get_origin(literal) is typing.Literal:
1029-
yield from self._flattened_typing_literal_args(literal)
1030-
else:
1031-
yield literal
1077+
return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None)
10321078

10331079
@property
10341080
def signature(self):
@@ -1048,17 +1094,16 @@ def signature(self):
10481094
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
10491095
# parameter signature is a literal list of it's values
10501096
annotation = param.annotation.converter if greedy else param.annotation
1051-
origin = typing.get_origin(annotation)
1052-
if not greedy and origin is typing.Union:
1053-
union_args = typing.get_args(annotation)
1097+
origin = get_typing_origin(annotation)
1098+
if not greedy and origin is Union:
1099+
union_args = get_typing_args(annotation)
10541100
optional = union_args[-1] is type(None)
10551101
if optional:
10561102
annotation = union_args[0]
1057-
origin = typing.get_origin(annotation)
1103+
origin = get_typing_origin(annotation)
10581104

1059-
if origin is typing.Literal:
1060-
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v)
1061-
for v in self._flattened_typing_literal_args(annotation))
1105+
if origin is Literal:
1106+
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
10621107
if param.default is not param.empty:
10631108
# We don't want None or '' to trigger the [name=value] case and instead it should
10641109
# do [name] since [name=None] or [name=] are not exactly useful for the user.

0 commit comments

Comments
 (0)