Skip to content

Commit

Permalink
[commands] Refactor typing evaluation to not use get_type_hints
Browse files Browse the repository at this point in the history
get_type_hints had a few issues:

1. It would convert = None default parameters to Optional
2. It would not allow values as type annotations
3. It would not implicitly convert some string literals as ForwardRef

In Python 3.9 `list['Foo']` does not convert into
`list[ForwardRef('Foo')]` even though `typing.List` does this
behaviour. In order to streamline it, evaluation had to be rewritten
manually to support our usecases.

This patch also flattens nested typing.Literal which was not done
until Python 3.9.2.
  • Loading branch information
Rapptz committed Apr 11, 2021
1 parent 27886e5 commit cf98dc6
Showing 1 changed file with 103 additions and 58 deletions.
161 changes: 103 additions & 58 deletions discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,20 @@
DEALINGS IN THE SOFTWARE.
"""

from typing import (
Any,
Dict,
ForwardRef,
Iterable,
Literal,
Tuple,
Union,
get_args as get_typing_args,
get_origin as get_typing_origin,
)
import asyncio
import functools
import inspect
import typing
import datetime
import sys

Expand Down Expand Up @@ -64,6 +74,83 @@
'bot_has_guild_permissions'
)

PY_310 = sys.version_info >= (3, 10)

def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
params = []
literal_cls = type(Literal[0])
for p in parameters:
if isinstance(p, literal_cls):
params.extend(p.__args__)
else:
params.append(p)
return tuple(params)

def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals
implicit_str = True

if implicit_str and isinstance(tp, str):
if tp in cache:
return cache[tp]
evaluated = eval(tp, globals)
cache[tp] = evaluated
return _evaluate_annotation(evaluated, globals, cache)

if hasattr(tp, '__args__'):
implicit_str = True
args = tp.__args__
if tp.__origin__ is Literal:
if not PY_38:
args = flatten_literal_params(tp.__args__)
implicit_str = False

evaluated_args = tuple(
_evaluate_annotation(arg, globals, cache, implicit_str=implicit_str) for arg in args
)

if evaluated_args == args:
return tp

try:
return tp.copy_with(evaluated_args)
except AttributeError:
return tp.__origin__[evaluated_args]

return tp

def resolve_annotation(annotation: Any, globalns: Dict[str, Any], cache: Dict[str, Any] = {}) -> Any:
if annotation is None:
return type(None)
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
return _evaluate_annotation(annotation, globalns, cache)

def get_signature_parameters(function) -> Dict[str, inspect.Parameter]:
globalns = function.__globals__
signature = inspect.signature(function)
params = {}
cache: Dict[str, Any] = {}
for name, parameter in signature.parameters.items():
annotation = parameter.annotation
if annotation is parameter.empty:
params[name] = parameter
continue
if annotation is None:
params[name] = parameter.replace(annotation=type(None))
continue

annotation = _evaluate_annotation(annotation, globalns, cache)
if annotation is converters.Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')

params[name] = parameter.replace(annotation=annotation)

return params


def wrap_callback(coro):
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
Expand Down Expand Up @@ -300,40 +387,7 @@ def callback(self):
def callback(self, function):
self._callback = function
self.module = function.__module__

signature = inspect.signature(function)
self.params = signature.parameters.copy()

# see: https://bugs.python.org/issue41341
resolve = self._recursive_resolve if sys.version_info < (3, 9) else self._return_resolved

try:
type_hints = {k: resolve(v) for k, v in typing.get_type_hints(function).items()}
except NameError as e:
raise NameError(f'unresolved forward reference: {e.args[0]}') from None

for key, value in self.params.items():
# coalesce the forward references
if key in type_hints:
self.params[key] = value = value.replace(annotation=type_hints[key])

# fail early for when someone passes an unparameterized Greedy type
if value.annotation is converters.Greedy:
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')

def _return_resolved(self, type, **kwargs):
return type

def _recursive_resolve(self, type, *, globals=None):
if not isinstance(type, typing.ForwardRef):
return type

resolved = eval(type.__forward_arg__, globals)
args = typing.get_args(resolved)
for index, arg in enumerate(args):
inner_resolve_result = self._recursive_resolve(arg, globals=globals)
resolved[index] = inner_resolve_result
return resolved
self.params = get_signature_parameters(function)

def add_check(self, func):
"""Adds a check to the command.
Expand Down Expand Up @@ -493,12 +547,12 @@ async def _actual_conversion(self, ctx, converter, argument, param):
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc

async def do_conversion(self, ctx, converter, argument, param):
origin = typing.get_origin(converter)
origin = get_typing_origin(converter)

if origin is typing.Union:
if origin is Union:
errors = []
_NoneType = type(None)
for conv in typing.get_args(converter):
for conv in get_typing_args(converter):
# if we got to this part in the code, then the previous conversions have failed
# so we should just undo the view, return the default, and allow parsing to continue
# with the other parameters
Expand All @@ -514,13 +568,12 @@ async def do_conversion(self, ctx, converter, argument, param):
return value

# if we're here, then we failed all the converters
raise BadUnionArgument(param, typing.get_args(converter), errors)
raise BadUnionArgument(param, get_typing_args(converter), errors)

if origin is typing.Literal:
if origin is Literal:
errors = []
conversions = {}
literal_args = tuple(self._flattened_typing_literal_args(converter))
for literal in literal_args:
for literal in converter.__args__:
literal_type = type(literal)
try:
value = conversions[literal_type]
Expand All @@ -538,7 +591,7 @@ async def do_conversion(self, ctx, converter, argument, param):
return value

# if we're here, then we failed to match all the literals
raise BadLiteralArgument(param, literal_args, errors)
raise BadLiteralArgument(param, converter.__args__, errors)

return await self._actual_conversion(ctx, converter, argument, param)

Expand Down Expand Up @@ -1021,14 +1074,7 @@ def short_doc(self):
return ''

def _is_typing_optional(self, annotation):
return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None)

def _flattened_typing_literal_args(self, annotation):
for literal in typing.get_args(annotation):
if typing.get_origin(literal) is typing.Literal:
yield from self._flattened_typing_literal_args(literal)
else:
yield literal
return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None)

@property
def signature(self):
Expand All @@ -1048,17 +1094,16 @@ def signature(self):
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
annotation = param.annotation.converter if greedy else param.annotation
origin = typing.get_origin(annotation)
if not greedy and origin is typing.Union:
union_args = typing.get_args(annotation)
origin = get_typing_origin(annotation)
if not greedy and origin is Union:
union_args = get_typing_args(annotation)
optional = union_args[-1] is type(None)
if optional:
annotation = union_args[0]
origin = typing.get_origin(annotation)
origin = get_typing_origin(annotation)

if origin is typing.Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v)
for v in self._flattened_typing_literal_args(annotation))
if origin is Literal:
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
if param.default is not param.empty:
# We don't want None or '' to trigger the [name=value] case and instead it should
# do [name] since [name=None] or [name=] are not exactly useful for the user.
Expand Down

0 comments on commit cf98dc6

Please sign in to comment.