Skip to content

Commit

Permalink
Add support for sync DI to sync callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
FasterSpeeding committed Oct 4, 2022
1 parent cf145f4 commit ac6a8c0
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 87 deletions.
103 changes: 64 additions & 39 deletions tanjun/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@
from . import permissions
from ._internal import localisation

if typing.TYPE_CHECKING:
import typing_extensions

_P = typing_extensions.ParamSpec("_P")

_PermissionErrorSigBase = collections.Callable[typing_extensions.Concatenate[hikari.Permissions, _P], Exception]
_PermissionErrorSig = _PermissionErrorSigBase[...]


_CommandT = typing.TypeVar("_CommandT", bound="tanjun.ExecutableCommand[typing.Any]")
# This errors on earlier 3.9 releases when not quotes cause dumb handling of the [_CommandT] list
_CallbackReturnT = typing.Union[_CommandT, "collections.Callable[[_CommandT], _CommandT]"]
Expand Down Expand Up @@ -116,7 +125,7 @@ def _handle_result(
) -> bool:
if not result:
if self._error:
raise self._error(*args) from None
raise ctx.call_with_di(self._error, args) from None
if self._halt_execution:
raise errors.HaltExecution from None
if self._error_message:
Expand All @@ -137,7 +146,7 @@ class OwnerCheck(_Check):
def __init__(
self,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Only bot owners can use this command",
halt_execution: bool = False,
) -> None:
Expand All @@ -148,7 +157,8 @@ def __init__(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positional arguments, supports sync DI and takes
priority over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -216,7 +226,7 @@ class NsfwCheck(_Check):
def __init__(
self,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Command can only be used in NSFW channels",
Expand All @@ -229,7 +239,8 @@ def __init__(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positional arguments, supports sync DI and takes
priority over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -269,7 +280,7 @@ class SfwCheck(_Check):
def __init__(
self,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Command can only be used in SFW channels",
Expand All @@ -282,7 +293,8 @@ def __init__(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes
priority over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -322,7 +334,7 @@ class DmCheck(_Check):
def __init__(
self,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in DMs",
halt_execution: bool = False,
) -> None:
Expand All @@ -333,7 +345,8 @@ def __init__(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes
priority over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -369,7 +382,7 @@ class GuildCheck(_Check):
def __init__(
self,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Command can only be used in guild channels",
Expand All @@ -382,7 +395,8 @@ def __init__(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes
priority over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -420,7 +434,7 @@ def __init__(
permissions: typing.Union[hikari.Permissions, int],
/,
*,
error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None,
error: typing.Optional[_PermissionErrorSig] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "You don't have the permissions required to use this command",
Expand All @@ -436,7 +450,8 @@ def __init__(
Callback used to create a custom error to raise if the check fails.
This should take 1 positional argument of type [hikari.permissions.Permissions][]
which represents the missing permissions required for this command to run.
which represents the missing permissions required for this command
to run, return an [Exception][] to raise, and supports sync DI.
This takes priority over `error_message`.
error_message
Expand Down Expand Up @@ -499,7 +514,7 @@ def __init__(
permissions: typing.Union[hikari.Permissions, int],
/,
*,
error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None,
error: typing.Optional[_PermissionErrorSig] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Bot doesn't have the permissions required to run this command",
Expand All @@ -515,7 +530,8 @@ def __init__(
Callback used to create a custom error to raise if the check fails.
This should take 1 positional argument of type [hikari.permissions.Permissions][]
which represents the missing permissions required for this command to run.
which represents the missing permissions required for this command
to run, return an [Exception][] to raise, and supports sync DI.
This takes priority over `error_message`.
error_message
Expand Down Expand Up @@ -570,7 +586,7 @@ def with_dm_check(command: _CommandT, /) -> _CommandT:
@typing.overload
def with_dm_check(
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in DMs",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -582,7 +598,7 @@ def with_dm_check(
command: typing.Optional[_CommandT] = None,
/,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in DMs",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -596,7 +612,8 @@ def with_dm_check(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -631,7 +648,7 @@ def with_guild_check(command: _CommandT, /) -> _CommandT:
@typing.overload
def with_guild_check(
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Command can only be used in guild channels",
Expand All @@ -645,7 +662,7 @@ def with_guild_check(
command: typing.Optional[_CommandT] = None,
/,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Command can only be used in guild channels",
Expand All @@ -661,7 +678,8 @@ def with_guild_check(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -696,7 +714,7 @@ def with_nsfw_check(command: _CommandT, /) -> _CommandT:
@typing.overload
def with_nsfw_check(
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in NSFW channels",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -708,7 +726,7 @@ def with_nsfw_check(
command: typing.Optional[_CommandT] = None,
/,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in NSFW channels",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -722,7 +740,8 @@ def with_nsfw_check(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -757,7 +776,7 @@ def with_sfw_check(command: _CommandT, /) -> _CommandT:
@typing.overload
def with_sfw_check(
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in SFW channels",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -769,7 +788,7 @@ def with_sfw_check(
command: typing.Optional[_CommandT] = None,
/,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in SFW channels",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -783,7 +802,8 @@ def with_sfw_check(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -818,7 +838,7 @@ def with_owner_check(command: _CommandT, /) -> _CommandT:
@typing.overload
def with_owner_check(
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Only bot owners can use this command",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -830,7 +850,7 @@ def with_owner_check(
command: typing.Optional[_CommandT] = None,
/,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Only bot owners can use this command",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -844,7 +864,8 @@ def with_owner_check(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -874,7 +895,7 @@ def with_owner_check(
def with_author_permission_check(
permissions: typing.Union[hikari.Permissions, int],
*,
error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None,
error: typing.Optional[_PermissionErrorSig] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "You don't have the permissions required to use this command",
Expand All @@ -895,7 +916,8 @@ def with_author_permission_check(
Callback used to create a custom error to raise if the check fails.
This should take 1 positional argument of type [hikari.permissions.Permissions][]
which represents the missing permissions required for this command to run.
which represents the missing permissions required for this command to
run, return an [Exception][] to raise, and supports sync DI.
This takes priority over `error_message`.
error_message
Expand Down Expand Up @@ -929,7 +951,7 @@ def with_author_permission_check(
def with_own_permission_check(
permissions: typing.Union[hikari.Permissions, int],
*,
error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None,
error: typing.Optional[_PermissionErrorSig] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Bot doesn't have the permissions required to run this command",
Expand All @@ -950,7 +972,8 @@ def with_own_permission_check(
Callback used to create a custom error to raise if the check fails.
This should take 1 positional argument of type [hikari.permissions.Permissions][]
which represents the missing permissions required for this command to run.
which represents the missing permissions required for this command to
run, return an [Exception][] to raise, and supports sync DI.
This takes priority over `error_message`.
error_message
Expand Down Expand Up @@ -1076,7 +1099,7 @@ class _AnyChecks(_Check):
def __init__(
self,
checks: list[tanjun.CheckSig],
error: typing.Optional[collections.Callable[[], Exception]],
error: typing.Optional[collections.Callable[..., Exception]],
error_message: typing.Union[str, collections.Mapping[str, str], None],
halt_execution: bool,
suppress: tuple[type[Exception], ...],
Expand Down Expand Up @@ -1109,7 +1132,7 @@ def any_checks(
check: tanjun.CheckSig,
/,
*checks: tanjun.CheckSig,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None],
halt_execution: bool = False,
suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution),
Expand All @@ -1128,7 +1151,8 @@ def any_checks(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand All @@ -1153,7 +1177,7 @@ def with_any_checks(
check: tanjun.CheckSig,
/,
*checks: tanjun.CheckSig,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None],
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -1173,7 +1197,8 @@ def with_any_checks(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down
Loading

0 comments on commit ac6a8c0

Please sign in to comment.