Skip to content

Commit

Permalink
Add support for sync DI to sync callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
FasterSpeeding committed Oct 15, 2022
1 parent 0ba037b commit 2cfb2e7
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 87 deletions.
101 changes: 62 additions & 39 deletions tanjun/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@
from ._internal import localisation

if typing.TYPE_CHECKING:
import typing_extensions

_P = typing_extensions.ParamSpec("_P")

_PermissionErrorSigBase = collections.Callable[typing_extensions.Concatenate[hikari.Permissions, _P], Exception]
_PermissionErrorSig = _PermissionErrorSigBase[...]


class _AnyCallback(typing.Protocol):
async def __call__(
Expand Down Expand Up @@ -131,7 +138,7 @@ def _handle_result(
) -> bool:
if not result:
if self._error:
raise self._error(*args) from None
raise ctx.call_with_di(self._error, args) from None
if self._halt_execution:
raise errors.HaltExecution from None
if self._error_message:
Expand All @@ -152,7 +159,7 @@ class OwnerCheck(_Check):
def __init__(
self,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Only bot owners can use this command",
halt_execution: bool = False,
) -> None:
Expand All @@ -163,7 +170,8 @@ def __init__(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positional arguments, supports sync DI and takes
priority over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -231,7 +239,7 @@ class NsfwCheck(_Check):
def __init__(
self,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Command can only be used in NSFW channels",
Expand All @@ -244,7 +252,8 @@ def __init__(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positional arguments, supports sync DI and takes
priority over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -284,7 +293,7 @@ class SfwCheck(_Check):
def __init__(
self,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Command can only be used in SFW channels",
Expand All @@ -297,7 +306,8 @@ def __init__(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes
priority over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -337,7 +347,7 @@ class DmCheck(_Check):
def __init__(
self,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in DMs",
halt_execution: bool = False,
) -> None:
Expand All @@ -348,7 +358,8 @@ def __init__(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes
priority over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -384,7 +395,7 @@ class GuildCheck(_Check):
def __init__(
self,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Command can only be used in guild channels",
Expand All @@ -397,7 +408,8 @@ def __init__(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes
priority over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -435,7 +447,7 @@ def __init__(
permissions: typing.Union[hikari.Permissions, int],
/,
*,
error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None,
error: typing.Optional[_PermissionErrorSig] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "You don't have the permissions required to use this command",
Expand All @@ -451,7 +463,8 @@ def __init__(
Callback used to create a custom error to raise if the check fails.
This should take 1 positional argument of type [hikari.permissions.Permissions][]
which represents the missing permissions required for this command to run.
which represents the missing permissions required for this command
to run, return an [Exception][] to raise, and supports sync DI.
This takes priority over `error_message`.
error_message
Expand Down Expand Up @@ -514,7 +527,7 @@ def __init__(
permissions: typing.Union[hikari.Permissions, int],
/,
*,
error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None,
error: typing.Optional[_PermissionErrorSig] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Bot doesn't have the permissions required to run this command",
Expand All @@ -530,7 +543,8 @@ def __init__(
Callback used to create a custom error to raise if the check fails.
This should take 1 positional argument of type [hikari.permissions.Permissions][]
which represents the missing permissions required for this command to run.
which represents the missing permissions required for this command
to run, return an [Exception][] to raise, and supports sync DI.
This takes priority over `error_message`.
error_message
Expand Down Expand Up @@ -585,7 +599,7 @@ def with_dm_check(command: _CommandT, /) -> _CommandT:
@typing.overload
def with_dm_check(
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in DMs",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -597,7 +611,7 @@ def with_dm_check(
command: typing.Optional[_CommandT] = None,
/,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in DMs",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -611,7 +625,8 @@ def with_dm_check(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -646,7 +661,7 @@ def with_guild_check(command: _CommandT, /) -> _CommandT:
@typing.overload
def with_guild_check(
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Command can only be used in guild channels",
Expand All @@ -660,7 +675,7 @@ def with_guild_check(
command: typing.Optional[_CommandT] = None,
/,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Command can only be used in guild channels",
Expand All @@ -676,7 +691,8 @@ def with_guild_check(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -711,7 +727,7 @@ def with_nsfw_check(command: _CommandT, /) -> _CommandT:
@typing.overload
def with_nsfw_check(
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in NSFW channels",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -723,7 +739,7 @@ def with_nsfw_check(
command: typing.Optional[_CommandT] = None,
/,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in NSFW channels",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -737,7 +753,8 @@ def with_nsfw_check(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -772,7 +789,7 @@ def with_sfw_check(command: _CommandT, /) -> _CommandT:
@typing.overload
def with_sfw_check(
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in SFW channels",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -784,7 +801,7 @@ def with_sfw_check(
command: typing.Optional[_CommandT] = None,
/,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in SFW channels",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -798,7 +815,8 @@ def with_sfw_check(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -833,7 +851,7 @@ def with_owner_check(command: _CommandT, /) -> _CommandT:
@typing.overload
def with_owner_check(
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Only bot owners can use this command",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -845,7 +863,7 @@ def with_owner_check(
command: typing.Optional[_CommandT] = None,
/,
*,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None] = "Only bot owners can use this command",
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -859,7 +877,8 @@ def with_owner_check(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down Expand Up @@ -889,7 +908,7 @@ def with_owner_check(
def with_author_permission_check(
permissions: typing.Union[hikari.Permissions, int],
*,
error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None,
error: typing.Optional[_PermissionErrorSig] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "You don't have the permissions required to use this command",
Expand All @@ -910,7 +929,8 @@ def with_author_permission_check(
Callback used to create a custom error to raise if the check fails.
This should take 1 positional argument of type [hikari.permissions.Permissions][]
which represents the missing permissions required for this command to run.
which represents the missing permissions required for this command to
run, return an [Exception][] to raise, and supports sync DI.
This takes priority over `error_message`.
error_message
Expand Down Expand Up @@ -944,7 +964,7 @@ def with_author_permission_check(
def with_own_permission_check(
permissions: typing.Union[hikari.Permissions, int],
*,
error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None,
error: typing.Optional[_PermissionErrorSig] = None,
error_message: typing.Union[
str, collections.Mapping[str, str], None
] = "Bot doesn't have the permissions required to run this command",
Expand All @@ -965,7 +985,8 @@ def with_own_permission_check(
Callback used to create a custom error to raise if the check fails.
This should take 1 positional argument of type [hikari.permissions.Permissions][]
which represents the missing permissions required for this command to run.
which represents the missing permissions required for this command to
run, return an [Exception][] to raise, and supports sync DI.
This takes priority over `error_message`.
error_message
Expand Down Expand Up @@ -1091,7 +1112,7 @@ class _AnyChecks(_Check):
def __init__(
self,
checks: list[tanjun.CheckSig],
error: typing.Optional[collections.Callable[[], Exception]],
error: typing.Optional[collections.Callable[..., Exception]],
error_message: typing.Union[str, collections.Mapping[str, str], None],
halt_execution: bool,
suppress: tuple[type[Exception], ...],
Expand Down Expand Up @@ -1124,7 +1145,7 @@ def any_checks(
check: tanjun.CheckSig,
/,
*checks: tanjun.CheckSig,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None],
halt_execution: bool = False,
suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution),
Expand All @@ -1143,7 +1164,8 @@ def any_checks(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand All @@ -1168,7 +1190,7 @@ def with_any_checks(
check: tanjun.CheckSig,
/,
*checks: tanjun.CheckSig,
error: typing.Optional[collections.Callable[[], Exception]] = None,
error: typing.Optional[collections.Callable[..., Exception]] = None,
error_message: typing.Union[str, collections.Mapping[str, str], None],
follow_wrapped: bool = False,
halt_execution: bool = False,
Expand All @@ -1188,7 +1210,8 @@ def with_any_checks(
error
Callback used to create a custom error to raise if the check fails.
This takes priority over `error_message`.
This takes no positonal arguments, supports sync DI and takes priority
over `error_message`.
error_message
The error message to send in response as a command error if the check fails.
Expand Down
Loading

0 comments on commit 2cfb2e7

Please sign in to comment.