22
22
DEALINGS IN THE SOFTWARE.
23
23
"""
24
24
25
+ from typing import (
26
+ Any ,
27
+ Dict ,
28
+ ForwardRef ,
29
+ Iterable ,
30
+ Literal ,
31
+ Tuple ,
32
+ Union ,
33
+ get_args as get_typing_args ,
34
+ get_origin as get_typing_origin ,
35
+ )
25
36
import asyncio
26
37
import functools
27
38
import inspect
28
- import typing
29
39
import datetime
30
40
import sys
31
41
64
74
'bot_has_guild_permissions'
65
75
)
66
76
77
+ PY_310 = sys .version_info >= (3 , 10 )
78
+
79
+ def flatten_literal_params (parameters : Iterable [Any ]) -> Tuple [Any , ...]:
80
+ params = []
81
+ literal_cls = type (Literal [0 ])
82
+ for p in parameters :
83
+ if isinstance (p , literal_cls ):
84
+ params .extend (p .__args__ )
85
+ else :
86
+ params .append (p )
87
+ return tuple (params )
88
+
89
+ def _evaluate_annotation (tp : Any , globals : Dict [str , Any ], cache : Dict [str , Any ] = {}, * , implicit_str = True ):
90
+ if isinstance (tp , ForwardRef ):
91
+ tp = tp .__forward_arg__
92
+ # ForwardRefs always evaluate their internals
93
+ implicit_str = True
94
+
95
+ if implicit_str and isinstance (tp , str ):
96
+ if tp in cache :
97
+ return cache [tp ]
98
+ evaluated = eval (tp , globals )
99
+ cache [tp ] = evaluated
100
+ return _evaluate_annotation (evaluated , globals , cache )
101
+
102
+ if hasattr (tp , '__args__' ):
103
+ implicit_str = True
104
+ args = tp .__args__
105
+ if tp .__origin__ is Literal :
106
+ if not PY_38 :
107
+ args = flatten_literal_params (tp .__args__ )
108
+ implicit_str = False
109
+
110
+ evaluated_args = tuple (
111
+ _evaluate_annotation (arg , globals , cache , implicit_str = implicit_str ) for arg in args
112
+ )
113
+
114
+ if evaluated_args == args :
115
+ return tp
116
+
117
+ try :
118
+ return tp .copy_with (evaluated_args )
119
+ except AttributeError :
120
+ return tp .__origin__ [evaluated_args ]
121
+
122
+ return tp
123
+
124
+ def resolve_annotation (annotation : Any , globalns : Dict [str , Any ], cache : Dict [str , Any ] = {}) -> Any :
125
+ if annotation is None :
126
+ return type (None )
127
+ if isinstance (annotation , str ):
128
+ annotation = ForwardRef (annotation )
129
+ return _evaluate_annotation (annotation , globalns , cache )
130
+
131
+ def get_signature_parameters (function ) -> Dict [str , inspect .Parameter ]:
132
+ globalns = function .__globals__
133
+ signature = inspect .signature (function )
134
+ params = {}
135
+ cache : Dict [str , Any ] = {}
136
+ for name , parameter in signature .parameters .items ():
137
+ annotation = parameter .annotation
138
+ if annotation is parameter .empty :
139
+ params [name ] = parameter
140
+ continue
141
+ if annotation is None :
142
+ params [name ] = parameter .replace (annotation = type (None ))
143
+ continue
144
+
145
+ annotation = _evaluate_annotation (annotation , globalns , cache )
146
+ if annotation is converters .Greedy :
147
+ raise TypeError ('Unparameterized Greedy[...] is disallowed in signature.' )
148
+
149
+ params [name ] = parameter .replace (annotation = annotation )
150
+
151
+ return params
152
+
153
+
67
154
def wrap_callback (coro ):
68
155
@functools .wraps (coro )
69
156
async def wrapped (* args , ** kwargs ):
@@ -300,40 +387,7 @@ def callback(self):
300
387
def callback (self , function ):
301
388
self ._callback = function
302
389
self .module = function .__module__
303
-
304
- signature = inspect .signature (function )
305
- self .params = signature .parameters .copy ()
306
-
307
- # see: https://bugs.python.org/issue41341
308
- resolve = self ._recursive_resolve if sys .version_info < (3 , 9 ) else self ._return_resolved
309
-
310
- try :
311
- type_hints = {k : resolve (v ) for k , v in typing .get_type_hints (function ).items ()}
312
- except NameError as e :
313
- raise NameError (f'unresolved forward reference: { e .args [0 ]} ' ) from None
314
-
315
- for key , value in self .params .items ():
316
- # coalesce the forward references
317
- if key in type_hints :
318
- self .params [key ] = value = value .replace (annotation = type_hints [key ])
319
-
320
- # fail early for when someone passes an unparameterized Greedy type
321
- if value .annotation is converters .Greedy :
322
- raise TypeError ('Unparameterized Greedy[...] is disallowed in signature.' )
323
-
324
- def _return_resolved (self , type , ** kwargs ):
325
- return type
326
-
327
- def _recursive_resolve (self , type , * , globals = None ):
328
- if not isinstance (type , typing .ForwardRef ):
329
- return type
330
-
331
- resolved = eval (type .__forward_arg__ , globals )
332
- args = typing .get_args (resolved )
333
- for index , arg in enumerate (args ):
334
- inner_resolve_result = self ._recursive_resolve (arg , globals = globals )
335
- resolved [index ] = inner_resolve_result
336
- return resolved
390
+ self .params = get_signature_parameters (function )
337
391
338
392
def add_check (self , func ):
339
393
"""Adds a check to the command.
@@ -493,12 +547,12 @@ async def _actual_conversion(self, ctx, converter, argument, param):
493
547
raise BadArgument (f'Converting to "{ name } " failed for parameter "{ param .name } ".' ) from exc
494
548
495
549
async def do_conversion (self , ctx , converter , argument , param ):
496
- origin = typing . get_origin (converter )
550
+ origin = get_typing_origin (converter )
497
551
498
- if origin is typing . Union :
552
+ if origin is Union :
499
553
errors = []
500
554
_NoneType = type (None )
501
- for conv in typing . get_args (converter ):
555
+ for conv in get_typing_args (converter ):
502
556
# if we got to this part in the code, then the previous conversions have failed
503
557
# so we should just undo the view, return the default, and allow parsing to continue
504
558
# with the other parameters
@@ -514,13 +568,12 @@ async def do_conversion(self, ctx, converter, argument, param):
514
568
return value
515
569
516
570
# if we're here, then we failed all the converters
517
- raise BadUnionArgument (param , typing . get_args (converter ), errors )
571
+ raise BadUnionArgument (param , get_typing_args (converter ), errors )
518
572
519
- if origin is typing . Literal :
573
+ if origin is Literal :
520
574
errors = []
521
575
conversions = {}
522
- literal_args = tuple (self ._flattened_typing_literal_args (converter ))
523
- for literal in literal_args :
576
+ for literal in converter .__args__ :
524
577
literal_type = type (literal )
525
578
try :
526
579
value = conversions [literal_type ]
@@ -538,7 +591,7 @@ async def do_conversion(self, ctx, converter, argument, param):
538
591
return value
539
592
540
593
# if we're here, then we failed to match all the literals
541
- raise BadLiteralArgument (param , literal_args , errors )
594
+ raise BadLiteralArgument (param , converter . __args__ , errors )
542
595
543
596
return await self ._actual_conversion (ctx , converter , argument , param )
544
597
@@ -1021,14 +1074,7 @@ def short_doc(self):
1021
1074
return ''
1022
1075
1023
1076
def _is_typing_optional (self , annotation ):
1024
- return typing .get_origin (annotation ) is typing .Union and typing .get_args (annotation )[- 1 ] is type (None )
1025
-
1026
- def _flattened_typing_literal_args (self , annotation ):
1027
- for literal in typing .get_args (annotation ):
1028
- if typing .get_origin (literal ) is typing .Literal :
1029
- yield from self ._flattened_typing_literal_args (literal )
1030
- else :
1031
- yield literal
1077
+ return get_typing_origin (annotation ) is Union and get_typing_args (annotation )[- 1 ] is type (None )
1032
1078
1033
1079
@property
1034
1080
def signature (self ):
@@ -1048,17 +1094,16 @@ def signature(self):
1048
1094
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
1049
1095
# parameter signature is a literal list of it's values
1050
1096
annotation = param .annotation .converter if greedy else param .annotation
1051
- origin = typing . get_origin (annotation )
1052
- if not greedy and origin is typing . Union :
1053
- union_args = typing . get_args (annotation )
1097
+ origin = get_typing_origin (annotation )
1098
+ if not greedy and origin is Union :
1099
+ union_args = get_typing_args (annotation )
1054
1100
optional = union_args [- 1 ] is type (None )
1055
1101
if optional :
1056
1102
annotation = union_args [0 ]
1057
- origin = typing . get_origin (annotation )
1103
+ origin = get_typing_origin (annotation )
1058
1104
1059
- if origin is typing .Literal :
1060
- name = '|' .join (f'"{ v } "' if isinstance (v , str ) else str (v )
1061
- for v in self ._flattened_typing_literal_args (annotation ))
1105
+ if origin is Literal :
1106
+ name = '|' .join (f'"{ v } "' if isinstance (v , str ) else str (v ) for v in annotation .__args__ )
1062
1107
if param .default is not param .empty :
1063
1108
# We don't want None or '' to trigger the [name=value] case and instead it should
1064
1109
# do [name] since [name=None] or [name=] are not exactly useful for the user.
0 commit comments