diff --git a/README.md b/README.md index 62918af..0ffff75 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ bot = Bot( homeserver="https://matrix.org", # your homeserver user_id="@__example__:matrix.org", # the user ID to log in as (Fully qualified) command_prefix="!", # the prefix to respond to (case sensitive, must be lowercase if below is True) - case_insensitive=True, # messages will be lower()cased before being handled. This is recommended. + case_insensitive=True, # messages will be casefold()ed before being handled. This is recommended. owner_id="@owner:homeserver.com" # The user ID who owns this bot. Optional, but required for bot.is_owner(...). ) diff --git a/src/niobot/__init__.py b/src/niobot/__init__.py index 8ee6e41..ca3eb4d 100644 --- a/src/niobot/__init__.py +++ b/src/niobot/__init__.py @@ -9,7 +9,7 @@ from .utils import * try: - import __version__ as version_meta + import __version__ as version_meta # type: ignore except ImportError: class __VersionMeta: diff --git a/src/niobot/__main__.py b/src/niobot/__main__.py index 983f4b4..27bc519 100644 --- a/src/niobot/__main__.py +++ b/src/niobot/__main__.py @@ -194,7 +194,7 @@ def test_homeserver(homeserver: str): parsed = urllib.parse.urlparse(homeserver) if not parsed.scheme: logger.info("No scheme found, assuming HTTPS.") - parsed = urllib.parse.urlparse("https://" + homeserver) + parsed = urllib.parse.urlparse(f"https://{homeserver}") if not parsed.netloc: logger.critical("No netloc found, cannot continue.") @@ -205,7 +205,7 @@ def test_homeserver(homeserver: str): logger.info("Trying well-known of %r...", parsed.netloc) base_url = None try: - response = httpx.get("https://%s/.well-known/matrix/client" % parsed.netloc, timeout=30) + response = httpx.get(f"https://{parsed.netloc}/.well-known/matrix/client", timeout=30) except httpx.HTTPError as e: logger.critical("Failed to get well-known: %r", e) return @@ -238,13 +238,13 @@ def test_homeserver(homeserver: str): if not base_url: logger.info("No well-known found. Assuming %r as homeserver.", parsed.netloc) - base_url = urllib.parse.urlparse("https://" + parsed.netloc) + base_url = urllib.parse.urlparse(f"https://{parsed.netloc}") base_url = base_url.geturl() logger.info("Using %r as homeserver.", base_url) logger.info("Validating homeserver...") try: - response = httpx.get(base_url + "/_matrix/client/versions", timeout=30) + response = httpx.get(f"{base_url}/_matrix/client/versions", timeout=30) except httpx.HTTPError as e: logger.critical("Failed to get versions: %r", e) return @@ -310,7 +310,7 @@ def get_access_token(ctx, username: str, password: str, homeserver: str, device_ status_code = None try: response = httpx.post( - homeserver + "/_matrix/client/r0/login", + f"{homeserver}/_matrix/client/r0/login", json={ "type": "m.login.password", "identifier": {"type": "m.id.user", "user": username}, @@ -327,10 +327,10 @@ def get_access_token(ctx, username: str, password: str, homeserver: str, device_ response.raise_for_status() except httpx.HTTPError as e: click.secho("Failed!", fg="red", nl=False) - click.secho(" (%s)" % status_code or str(e), bg="red") + click.secho(f" ({status_code or str(e)})", bg="red") else: click.secho("OK", fg="green") - click.secho("Access token: %s" % response.json()["access_token"], fg="green") + click.secho(f'Access token: {response.json()["access_token"]}', fg="green") @cli_root.group() diff --git a/src/niobot/attachment.py b/src/niobot/attachment.py index 2c1169e..2e555b3 100644 --- a/src/niobot/attachment.py +++ b/src/niobot/attachment.py @@ -17,6 +17,7 @@ import urllib.parse import warnings from typing import Union as U +from typing import overload import aiofiles import aiohttp @@ -85,7 +86,7 @@ SUPPORTED_CODECS = SUPPORTED_VIDEO_CODECS + SUPPORTED_AUDIO_CODECS + SUPPORTED_IMAGE_CODECS -def detect_mime_type(file: typing.Union[str, io.BytesIO, pathlib.Path]) -> str: +def detect_mime_type(file: U[str, io.BytesIO, pathlib.Path]) -> str: """ Detect the mime type of a file. @@ -112,7 +113,7 @@ def detect_mime_type(file: typing.Union[str, io.BytesIO, pathlib.Path]) -> str: raise TypeError("File must be a string, BytesIO, or Path object.") -def get_metadata_ffmpeg(file: typing.Union[str, pathlib.Path]) -> typing.Dict[str, typing.Any]: +def get_metadata_ffmpeg(file: U[str, pathlib.Path]) -> dict[str, typing.Any]: """ Gets metadata for a file via ffprobe. @@ -136,7 +137,7 @@ def get_metadata_ffmpeg(file: typing.Union[str, pathlib.Path]) -> typing.Dict[st return data -def get_metadata_imagemagick(file: pathlib.Path) -> typing.Dict[str, typing.Any]: +def get_metadata_imagemagick(file: pathlib.Path) -> dict[str, typing.Any]: """The same as `get_metadata_ffmpeg` but for ImageMagick. Only returns a limited subset of the data, such as one stream, which contains the format, and size, @@ -180,7 +181,7 @@ def get_metadata_imagemagick(file: pathlib.Path) -> typing.Dict[str, typing.Any] return data -def get_metadata(file: typing.Union[str, pathlib.Path], mime_type: str = None) -> typing.Dict[str, typing.Any]: +def get_metadata(file: U[str, pathlib.Path], mime_type: typing.Optional[str] = None) -> dict[str, typing.Any]: """ Gets metadata for a file. @@ -308,9 +309,19 @@ def _file_okay(file: U[pathlib.Path, io.BytesIO]) -> typing.Literal[True]: return True -def _to_path(file: U[str, pathlib.Path, io.BytesIO]) -> typing.Union[pathlib.Path, io.BytesIO]: +@overload +def _to_path(file: U[str, pathlib.Path]) -> pathlib.Path: + ... + + +@overload +def _to_path(file: io.BytesIO) -> io.BytesIO: + ... + + +def _to_path(file: U[str, pathlib.Path, io.BytesIO]) -> U[pathlib.Path, io.BytesIO]: """Converts a string to a Path object.""" - if not isinstance(file, (str, pathlib.PurePath, io.BytesIO)): + if not isinstance(file, (str, pathlib.Path, io.BytesIO)): raise TypeError("File must be a string, BytesIO, or Path object.") if isinstance(file, io.BytesIO): @@ -330,8 +341,8 @@ def _size(file: U[pathlib.Path, io.BytesIO]) -> int: def which( - file: U[io.BytesIO, pathlib.Path, str], mime_type: str = None -) -> typing.Union[ + file: U[io.BytesIO, pathlib.Path, str], mime_type: typing.Optional[str] = None +) -> U[ typing.Type["FileAttachment"], typing.Type["ImageAttachment"], typing.Type["AudioAttachment"], @@ -423,30 +434,36 @@ class BaseAttachment(abc.ABC): """ if typing.TYPE_CHECKING: - file: typing.Union[pathlib.Path, io.BytesIO] + file: U[pathlib.Path, io.BytesIO] file_name: str mime_type: str size: int type: AttachmentType url: typing.Optional[str] - keys: typing.Optional[typing.Dict[str, str]] + keys: typing.Optional[dict[str, str]] def __init__( self, - file: typing.Union[str, io.BytesIO, pathlib.Path], - file_name: str = None, - mime_type: str = None, - size_bytes: int = None, + file: U[str, io.BytesIO, pathlib.Path], + file_name: typing.Optional[str] = None, + mime_type: typing.Optional[str] = None, + size_bytes: typing.Optional[int] = None, *, attachment_type: AttachmentType = AttachmentType.FILE, ): self.file = _to_path(file) - self.file_name = self.file.name if isinstance(self.file, pathlib.Path) else file_name + # Ignore type error as the type is checked right afterwards + self.file_name = self.file.name if isinstance(self.file, pathlib.Path) else file_name # type: ignore if not self.file_name: raise ValueError("file_name must be specified when uploading a BytesIO object.") self.mime_type = mime_type or detect_mime_type(self.file) - self.size = size_bytes or os.path.getsize(self.file) + if size_bytes: + self.size = size_bytes + elif isinstance(self.file, io.BytesIO): + self.size = len(self.file.getbuffer()) + else: + os.path.getsize(self.file) self.type = attachment_type self.url = None @@ -458,14 +475,14 @@ def __repr__(self): "mime_type={0.mime_type!r} size={0.size!r} type={0.type!r}>".format(self) ) - def as_body(self, body: str = None) -> dict: + def as_body(self, body: typing.Optional[str] = None) -> dict: """ Generates the body for the attachment for sending. The attachment must've been uploaded first. :param body: The body to use (should be a textual description). Defaults to the file name. :return: """ - body = { + output_body = { "body": body or self.file_name, "info": { "mimetype": self.mime_type, @@ -476,14 +493,14 @@ def as_body(self, body: str = None) -> dict: "url": self.url, } if self.keys: - body["file"] = self.keys - return body + output_body["file"] = self.keys + return output_body @classmethod async def from_file( cls, - file: typing.Union[str, io.BytesIO, pathlib.Path], - file_name: str = None, + file: U[str, io.BytesIO, pathlib.Path], + file_name: typing.Optional[str] = None, ) -> "BaseAttachment": """ Creates an attachment from a file. @@ -495,11 +512,10 @@ async def from_file( :return: Loaded attachment. """ file = _to_path(file) - if isinstance(file, io.BytesIO): - if not file_name: + if not file_name: + if isinstance(file, io.BytesIO): raise ValueError("file_name must be specified when uploading a BytesIO object.") - else: - if not file_name: + else: file_name = file.name mime_type = await run_blocking(detect_mime_type, file) @@ -537,7 +553,7 @@ async def from_mxc( async def from_http( cls, url: str, - client_session: aiohttp.ClientSession = None, + client_session: typing.Optional[aiohttp.ClientSession] = None, *, force_write: U[bool, pathlib.Path] = False, ) -> "BaseAttachment": @@ -588,14 +604,14 @@ async def from_http( else: save_path = tempdir - if save_path is not None: - async with aiofiles.open(save_path, "wb") as fh: - async for chunk in response.content.iter_chunked(1024): - await fh.write(chunk) - return await cls.from_file(save_path, file_name) - else: + if save_path is None: return await cls.from_file(io.BytesIO(await response.read()), file_name) + async with aiofiles.open(save_path, "wb") as fh: + async for chunk in response.content.iter_chunked(1024): + await fh.write(chunk) + return await cls.from_file(save_path, file_name) + @property def size_bytes(self) -> int: """Returns the size of this attachment in bytes.""" @@ -612,7 +628,7 @@ def size_as( "gb", "gib", ], - ) -> typing.Union[int, float]: + ) -> U[int, float]: """ Helper function to convert the size of this attachment into a different unit. @@ -715,16 +731,16 @@ class SupportXYZAmorganBlurHash(BaseAttachment): if typing.TYPE_CHECKING: xyz_amorgan_blurhash: str - def __init__(self, *args, xyz_amorgan_blurhash: str = None, **kwargs): + def __init__(self, *args, xyz_amorgan_blurhash: typing.Optional[str] = None, **kwargs): super().__init__(*args, **kwargs) self.xyz_amorgan_blurhash = xyz_amorgan_blurhash @classmethod async def from_file( cls, - file: typing.Union[str, io.BytesIO, pathlib.Path], - file_name: str = None, - xyz_amorgan_blurhash: U[str, bool] = None, + file: U[str, io.BytesIO, pathlib.Path], + file_name: typing.Optional[str] = None, + xyz_amorgan_blurhash: typing.Optional[U[str, bool]] = None, ) -> "SupportXYZAmorganBlurHash": file = _to_path(file) if isinstance(file, io.BytesIO): @@ -765,7 +781,9 @@ def thumbnailify_image( return image async def get_blurhash( - self, quality: typing.Tuple[int, int] = (4, 3), file: U[str, pathlib.Path, io.BytesIO, PIL.Image.Image] = None + self, + quality: typing.Tuple[int, int] = (4, 3), + file: typing.Optional[U[str, pathlib.Path, io.BytesIO, PIL.Image.Image]] = None, ) -> str: """ Gets the blurhash of the attachment. See: [woltapp/blurhash](https://github.com/woltapp/blurhash) @@ -798,11 +816,11 @@ async def get_blurhash( self.xyz_amorgan_blurhash = x return x - def as_body(self, body: str = None) -> dict: - body = super().as_body(body) + def as_body(self, body: typing.Optional[str] = None) -> dict: + output_body = super().as_body(body) if isinstance(self.xyz_amorgan_blurhash, str): - body["info"]["xyz.amorgan.blurhash"] = self.xyz_amorgan_blurhash - return body + output_body["info"]["xyz.amorgan.blurhash"] = self.xyz_amorgan_blurhash + return output_body class FileAttachment(BaseAttachment): @@ -822,10 +840,10 @@ class FileAttachment(BaseAttachment): def __init__( self, - file: typing.Union[str, io.BytesIO, pathlib.Path], - file_name: str = None, - mime_type: str = None, - size_bytes: int = None, + file: U[str, io.BytesIO, pathlib.Path], + file_name: typing.Optional[str] = None, + mime_type: typing.Optional[str] = None, + size_bytes: typing.Optional[int] = None, ): super().__init__(file, file_name, mime_type, size_bytes, attachment_type=AttachmentType.FILE) @@ -849,14 +867,14 @@ class ImageAttachment(SupportXYZAmorganBlurHash): def __init__( self, - file: typing.Union[str, io.BytesIO, pathlib.Path], - file_name: str = None, - mime_type: str = None, - size_bytes: int = None, - height: int = None, - width: int = None, - thumbnail: "ImageAttachment" = None, - xyz_amorgan_blurhash: str = None, + file: U[str, io.BytesIO, pathlib.Path], + file_name: typing.Optional[str] = None, + mime_type: typing.Optional[str] = None, + size_bytes: typing.Optional[int] = None, + height: typing.Optional[int] = None, + width: typing.Optional[int] = None, + thumbnail: typing.Optional["ImageAttachment"] = None, + xyz_amorgan_blurhash: typing.Optional[str] = None, ): super().__init__( file, @@ -877,11 +895,11 @@ def __init__( @classmethod async def from_file( cls, - file: typing.Union[str, io.BytesIO, pathlib.Path], - file_name: str = None, - height: int = None, - width: int = None, - thumbnail: "ImageAttachment" = None, + file: U[str, io.BytesIO, pathlib.Path], + file_name: typing.Optional[str] = None, + height: typing.Optional[int] = None, + width: typing.Optional[int] = None, + thumbnail: typing.Optional["ImageAttachment"] = None, generate_blurhash: bool = True, *, unsafe: bool = False, @@ -929,15 +947,15 @@ async def from_file( await self.get_blurhash() return self - def as_body(self, body: str = None) -> dict: - body = super().as_body(body) - body["info"] = {**body["info"], **self.info} + def as_body(self, body: typing.Optional[str] = None) -> dict: + output_body = super().as_body(body) + output_body["info"] = {**output_body["info"], **self.info} if self.thumbnail: if self.thumbnail.keys: - body["info"]["thumbnail_file"] = self.thumbnail.keys - body["info"]["thumbnail_info"] = self.thumbnail.info - body["info"]["thumbnail_url"] = self.thumbnail.url - return body + output_body["info"]["thumbnail_file"] = self.thumbnail.keys + output_body["info"]["thumbnail_info"] = self.thumbnail.info + output_body["info"]["thumbnail_url"] = self.thumbnail.url + return output_body class VideoAttachment(BaseAttachment): @@ -956,14 +974,14 @@ class VideoAttachment(BaseAttachment): def __init__( self, - file: typing.Union[str, io.BytesIO, pathlib.Path], - file_name: str = None, - mime_type: str = None, - size_bytes: int = None, - duration: int = None, - height: int = None, - width: int = None, - thumbnail: "ImageAttachment" = None, + file: U[str, io.BytesIO, pathlib.Path], + file_name: typing.Optional[str] = None, + mime_type: typing.Optional[str] = None, + size_bytes: typing.Optional[int] = None, + duration: typing.Optional[int] = None, + height: typing.Optional[int] = None, + width: typing.Optional[int] = None, + thumbnail: typing.Optional["ImageAttachment"] = None, ): super().__init__(file, file_name, mime_type, size_bytes, attachment_type=AttachmentType.VIDEO) self.info = { @@ -978,12 +996,12 @@ def __init__( @classmethod async def from_file( cls, - file: typing.Union[str, io.BytesIO, pathlib.Path], - file_name: str = None, - duration: int = None, - height: int = None, - width: int = None, - thumbnail: U[ImageAttachment, typing.Literal[False]] = None, + file: U[str, io.BytesIO, pathlib.Path], + file_name: typing.Optional[str] = None, + duration: typing.Optional[int] = None, + height: typing.Optional[int] = None, + width: typing.Optional[int] = None, + thumbnail: typing.Optional[U[ImageAttachment, typing.Literal[False]]] = None, generate_blurhash: bool = True, ) -> "VideoAttachment": """ @@ -1046,12 +1064,14 @@ async def from_file( if isinstance(self.thumbnail, ImageAttachment): await self.thumbnail.get_blurhash() elif isinstance(file, pathlib.Path) and original_thumbnail is not False: - thumbnail = await run_blocking(first_frame, file) - self.thumbnail = await ImageAttachment.from_file(io.BytesIO(thumbnail), file_name="thumbnail.webp") + thumbnail_bytes = await run_blocking(first_frame, file) + self.thumbnail = await ImageAttachment.from_file( + io.BytesIO(thumbnail_bytes), file_name="thumbnail.webp" + ) return self @staticmethod - async def generate_thumbnail(video: typing.Union[str, pathlib.Path, "VideoAttachment"]) -> ImageAttachment: + async def generate_thumbnail(video: U[str, pathlib.Path, "VideoAttachment"]) -> ImageAttachment: """ Generates a thumbnail for a video. @@ -1068,15 +1088,15 @@ async def generate_thumbnail(video: typing.Union[str, pathlib.Path, "VideoAttach x = await run_blocking(first_frame, video, "webp") return await ImageAttachment.from_file(io.BytesIO(x), file_name="thumbnail.webp") - def as_body(self, body: str = None) -> dict: - body = super().as_body(body) - body["info"] = {**body["info"], **self.info} + def as_body(self, body: typing.Optional[str] = None) -> dict: + output_body = super().as_body(body) + output_body["info"] = {**output_body["info"], **self.info} if self.thumbnail: if self.thumbnail.keys: - body["info"]["thumbnail_file"] = self.thumbnail.keys - body["info"]["thumbnail_info"] = self.thumbnail.info - body["info"]["thumbnail_url"] = self.thumbnail.url - return body + output_body["info"]["thumbnail_file"] = self.thumbnail.keys + output_body["info"]["thumbnail_info"] = self.thumbnail.info + output_body["info"]["thumbnail_url"] = self.thumbnail.url + return output_body class AudioAttachment(BaseAttachment): @@ -1086,11 +1106,11 @@ class AudioAttachment(BaseAttachment): def __init__( self, - file: typing.Union[str, io.BytesIO, pathlib.Path], - file_name: str = None, - mime_type: str = None, - size_bytes: int = None, - duration: int = None, + file: U[str, io.BytesIO, pathlib.Path], + file_name: typing.Optional[str] = None, + mime_type: typing.Optional[str] = None, + size_bytes: typing.Optional[int] = None, + duration: typing.Optional[int] = None, ): super().__init__(file, file_name, mime_type, size_bytes, attachment_type=AttachmentType.AUDIO) self.info = { @@ -1102,9 +1122,9 @@ def __init__( @classmethod async def from_file( cls, - file: typing.Union[str, io.BytesIO, pathlib.Path], - file_name: str = None, - duration: int = None, + file: U[str, io.BytesIO, pathlib.Path], + file_name: typing.Optional[str] = None, + duration: typing.Optional[int] = None, ) -> "AudioAttachment": """ Generates an audio attachment @@ -1130,7 +1150,7 @@ async def from_file( self = cls(file, file_name, mime_type, size, duration) return self - def as_body(self, body: str = None) -> dict: - body = super().as_body(body) - body["info"] = {**body["info"], **self.info} - return body + def as_body(self, body: typing.Optional[str] = None) -> dict: + output_body = super().as_body(body) + output_body["info"] = {**output_body["info"], **self.info} + return output_body diff --git a/src/niobot/client.py b/src/niobot/client.py index c74e88c..796a54a 100644 --- a/src/niobot/client.py +++ b/src/niobot/client.py @@ -13,10 +13,7 @@ import nio from nio.crypto import ENCRYPTION_ENABLED -try: - from .attachment import BaseAttachment -except ImportError: - BaseAttachment = None +from .attachment import BaseAttachment from .commands import Command, Module from .exceptions import * from .utils import Typing, force_await, run_blocking @@ -37,7 +34,7 @@ class NioBot(nio.AsyncClient): :param device_id: The device ID to log in as. e.g. nio-bot :param store_path: The path to the store file. Defaults to ./store. Must be a directory. :param command_prefix: The prefix to use for commands. e.g. ! - :param case_insensitive: Whether to ignore case when checking for commands. If True, this lower()s + :param case_insensitive: Whether to ignore case when checking for commands. If True, this casefold()s incoming messages for parsing. :param global_message_type: The message type to default to. Defaults to m.notice :param ignore_old_events: Whether to simply discard events before the bot's login. @@ -52,15 +49,15 @@ def __init__( homeserver: str, user_id: str, device_id: str = "nio-bot", - store_path: str = None, + store_path: typing.Optional[str] = None, *, command_prefix: typing.Union[str, re.Pattern], case_insensitive: bool = True, - owner_id: str = None, - config: nio.AsyncClientConfig = None, + owner_id: typing.Optional[str] = None, + config: typing.Optional[nio.AsyncClientConfig] = None, ssl: bool = True, - proxy: str = None, - help_command: typing.Union[Command, typing.Callable[["Context"], typing.Any]] = None, + proxy: typing.Optional[str] = None, + help_command: typing.Optional[typing.Union[Command, typing.Callable[["Context"], typing.Any]]] = None, global_message_type: str = "m.notice", ignore_old_events: bool = True, auto_join_rooms: bool = True, @@ -104,7 +101,7 @@ def __init__( if command_prefix == "/": self.log.warning("The prefix '/' may interfere with client-side commands on some clients, such as Element.") - if re.match(r"\s", command_prefix): + if isinstance(command_prefix, str) and re.match(r"\s", command_prefix): raise RuntimeError("Command prefix cannot contain whitespace.") self.start_time: typing.Optional[float] = None @@ -134,10 +131,9 @@ def __init__( # NOTE: `m.notice` prevents bot messages sending off room notifications, and shows darker text # (In element at least). - # noinspection PyTypeChecker - self.add_event_callback(self.process_message, nio.RoomMessageText) + self.add_event_callback(self.process_message, nio.RoomMessageText) # type: ignore self.add_event_callback(self.update_read_receipts, nio.RoomMessage) - self.direct_rooms: typing.Dict[str, nio.MatrixRoom] = {} + self.direct_rooms: dict[str, nio.MatrixRoom] = {} self.message_cache: typing.Deque[typing.Tuple[nio.MatrixRoom, nio.RoomMessageText]] = deque( maxlen=max_message_cache @@ -145,10 +141,9 @@ def __init__( self.is_ready = asyncio.Event() self._waiting_events = {} - if self.auto_join_rooms is True: + if self.auto_join_rooms: self.log.info("Auto-joining rooms enabled.") - # noinspection PyTypeChecker - self.add_event_callback(self._auto_join_room_backlog_callback, nio.InviteMemberEvent) + self.add_event_callback(self._auto_join_room_backlog_callback, nio.InviteMemberEvent) # type: ignore async def sync(self, *args, **kwargs) -> U[nio.SyncResponse, nio.SyncError]: sync = await super().sync(*args, **kwargs) @@ -184,7 +179,7 @@ async def _auto_join_room_backlog_callback(self, room: nio.MatrixRoom, event: ni await self._auto_join_room_callback(room, event) @staticmethod - def latency(event: nio.Event, *, received_at: float = None) -> float: + def latency(event: nio.Event, *, received_at: typing.Optional[float] = None) -> float: """Returns the latency for a given event in milliseconds :param event: The event to measure latency with @@ -212,7 +207,7 @@ def dispatch(self, event_name: str, *args, **kwargs): self.log.debug("%r is not in registered events: %s", event_name, self._events) def is_old(self, event: nio.Event) -> bool: - """Checks if an event was sent before the bot started. Always returns False when ignore_old_evens is False""" + """Checks if an event was sent before the bot started. Always returns False when ignore_old_events is False""" if not self.start_time: self.log.warning("have not started yet, using relative age comparison") start_time = time.time() - 30 # relative @@ -238,15 +233,19 @@ async def update_read_receipts(self, room: U[str, nio.MatrixRoom], event: nio.Ev if self.is_old(event): self.log.debug("Ignoring event %s, sent before bot started.", event.event_id) return - event = event.event_id - result = await self.room_read_markers(room, event, event) + event_id = event.event_id + result = await self.room_read_markers(room, event_id, event_id) if not isinstance(result, nio.RoomReadMarkersResponse): - self.log.warning("Failed to update read receipts for %s: %s", room, result.message) + msg = result.message if isinstance(result, nio.ErrorResponse) else "?" + self.log.warning("Failed to update read receipts for %s: %s", room, msg) else: self.log.debug("Updated read receipts for %s to %s.", room, event) - async def process_message(self, room: nio.MatrixRoom, event: nio.RoomMessageText): + async def process_message(self, room: nio.MatrixRoom, event: nio.RoomMessageText) -> None: """Processes a message and runs the command it is trying to invoke if any.""" + if self.start_time is None: + raise RuntimeError("Bot has not started yet!") + self.message_cache.append((room, event)) self.dispatch("message", room, event) if event.sender == self.user: @@ -258,7 +257,7 @@ async def process_message(self, room: nio.MatrixRoom, event: nio.RoomMessageText return if self.case_insensitive: - content = event.body.lower() + content = event.body.casefold() else: content = event.body @@ -356,7 +355,7 @@ def mount_module(self, import_path: str) -> typing.Optional[list[Command]]: populated), but the event loop will be running. :param import_path: The import path (such as modules.file), which would be ./modules/file.py in a file tree. - :returns: Optional[List[Command]] - A list of commands mounted. None if the module's setup() was called. + :returns: Optional[list[Command]] - A list of commands mounted. None if the module's setup() was called. :raise ImportError: The module path is incorrect of there was another error while importing :raise TypeError: The module was not a subclass of Module. :raise ValueError: There was an error registering a command (e.g. name conflict) @@ -391,7 +390,7 @@ def mount_module(self, import_path: str) -> typing.Optional[list[Command]]: return added @property - def commands(self) -> typing.Dict[str, Command]: + def commands(self) -> dict[str, Command]: """Returns the internal command register. !!! warning @@ -406,7 +405,7 @@ def commands(self) -> typing.Dict[str, Command]: return self._commands @property - def modules(self) -> typing.Dict[typing.Type[Module], Module]: + def modules(self) -> dict[typing.Type[Module], Module]: """Returns the internal module register. !!! warning @@ -449,7 +448,7 @@ def remove_command(self, command: Command) -> None: for alias in command.aliases: self.log.debug("Removed command %r from the register.", self._commands.pop(alias, None)) - def command(self, name: str = None, **kwargs): + def command(self, name: typing.Optional[str] = None, **kwargs): """Registers a command with the bot.""" cls = kwargs.pop("cls", Command) @@ -467,15 +466,15 @@ def add_event_listener(self, event_type: str, func): self._events[event_type].append(func) self.log.debug("Added event listener %r for %r", func, event_type) - def on_event(self, event_type: str = None): + def on_event(self, event_type: typing.Optional[str] = None): """Wrapper that allows you to register an event handler""" - if event_type.startswith("on_"): - self.log.warning("No events start with 'on_' - stripping prefix") - event_type = event_type[3:] def wrapper(func): nonlocal event_type event_type = event_type or func.__name__ + if event_type.startswith("on_"): + self.log.warning("No events start with 'on_' - stripping prefix") + event_type = event_type[3:] self.add_event_listener(event_type, func) return func @@ -530,11 +529,11 @@ async def fetch_message(self, room_id: str, event_id: str): async def wait_for_message( self, - room_id: str = None, - sender: str = None, - check: typing.Callable[[nio.MatrixRoom, nio.RoomMessageText], typing.Any] = None, + room_id: typing.Optional[str] = None, + sender: typing.Optional[str] = None, + check: typing.Optional[typing.Callable[[nio.MatrixRoom, nio.RoomMessageText], typing.Any]] = None, *, - timeout: float = None, + timeout: typing.Optional[float] = None, ) -> typing.Optional[typing.Tuple[nio.MatrixRoom, nio.RoomMessageText]]: """Waits for a message, optionally with a filter. @@ -543,14 +542,12 @@ async def wait_for_message( value = None async def event_handler(_room, _event): - if room_id: - if _room.room_id != room_id: - self.log.debug("Ignoring bubbling message from %r (vs %r)", _room.room_id, room_id) - return False - if sender: - if _event.sender != sender: - self.log.debug("Ignoring bubbling message from %r (vs %r)", _event.sender, sender) - return False + if room_id and _room.room_id != room_id: + self.log.debug("Ignoring bubbling message from %r (vs %r)", _room.room_id, room_id) + return False + if sender and _event.sender != sender: + self.log.debug("Ignoring bubbling message from %r (vs %r)", _event.sender, sender) + return False if check: try: result = await force_await(check, _room, _event) @@ -581,7 +578,13 @@ async def _markdown_to_html(text: str) -> str: return rendered @staticmethod - def _get_id(obj) -> str: + def _get_id(obj: typing.Union[nio.Event, nio.MatrixRoom, nio.MatrixUser, str, typing.Any]) -> str: + """Gets the id of most objects as a string. + :param obj: The object who's ID to get, or the ID itself. + :type obj: typing.Union[nio.Event, nio.MatrixRoom, nio.MatrixUser, str, Any] + :returns: the ID of the object + :raises: ValueError - the Object doesn't have an ID + """ if hasattr(obj, "event_id"): return obj.event_id if hasattr(obj, "room_id"): @@ -612,8 +615,8 @@ def generate_mx_reply(room: nio.MatrixRoom, event: nio.RoomMessageText) -> str: ) async def _recursively_upload_attachments( - self, base: "BaseAttachment", encrypted: bool = False, __previous: list = None - ) -> list[typing.Union[nio.UploadResponse, nio.UploadError, type(None)]]: + self, base: "BaseAttachment", encrypted: bool = False, __previous: typing.Optional[list] = None + ) -> list[typing.Union[nio.UploadResponse, nio.UploadError, None]]: """Recursively uploads attachments.""" previous = (__previous or []).copy() if not base.url: @@ -652,20 +655,21 @@ async def get_dm_room(self, user: U[nio.MatrixUser, str]) -> nio.MatrixRoom: if not isinstance(room, nio.RoomCreateResponse): raise NioBotException("Unable to create DM room for %r: %r" % (user_id, room), response=room) self.log.debug("Created DM room for %r: %r", user_id, room) - room = self.rooms.get(room.room_id) + room_id = room.room_id + room = self.rooms.get(room_id) if not room: - raise RuntimeError("DM room %r was created, but could not be found in the room list!" % room.room_id) + raise RuntimeError("DM room %r was created, but could not be found in the room list!" % room_id) self.direct_rooms[user_id] = room return room async def send_message( self, room: U[nio.MatrixRoom, nio.MatrixUser, str], - content: str = None, - file: BaseAttachment = None, - reply_to: U[nio.RoomMessageText, str] = None, - message_type: str = None, - clean_mentions: bool = False, + content: typing.Optional[str] = None, + file: typing.Optional[BaseAttachment] = None, + reply_to: typing.Optional[U[nio.RoomMessageText, str]] = None, + message_type: typing.Optional[str] = None, + clean_mentions: typing.Optional[bool] = False, ) -> nio.RoomSendResponse: """ Sends a message. @@ -696,7 +700,7 @@ async def send_message( self.log.debug("Send message resolved room to %r", room) - body = { + body: dict[str, typing.Any] = { "msgtype": message_type or self.global_message_type, } @@ -711,7 +715,7 @@ async def send_message( body = file.as_body(content) else: - if clean_mentions: + if clean_mentions and content: content = content.replace("@", "@\u200b") body["body"] = content if self.automatic_markdown_renderer: @@ -745,7 +749,7 @@ async def edit_message( message: U[nio.Event, str], content: str, *, - message_type: str = None, + message_type: typing.Optional[str] = None, clean_mentions: bool = False, ) -> nio.RoomSendResponse: """ @@ -761,11 +765,13 @@ async def edit_message( :raises RuntimeError: If you are not the sender of the message. :raises TypeError: If the message is not text. """ + room = self._get_id(room) + if clean_mentions: content = content.replace("@", "@\u200b") event_id = self._get_id(message) message_type = message_type or self.global_message_type - content = { + content_dict = { "msgtype": message_type, "body": content, "format": "org.matrix.custom.html", @@ -774,27 +780,27 @@ async def edit_message( body = { "msgtype": message_type, - "body": " * %s" % content["body"], - "m.new_content": {**content}, + "body": " * %s" % content_dict["body"], + "m.new_content": {**content_dict}, "m.relates_to": { "rel_type": "m.replace", "event_id": event_id, }, } - async with Typing(self, room.room_id): + async with Typing(self, room): response = await self.room_send( - self._get_id(room), + room, "m.room.message", body, ) if isinstance(response, nio.RoomSendError): raise MessageException("Failed to edit message.", response) # Forcefully clear typing - await self.room_typing(room.room_id, False) + await self.room_typing(room, False) return response async def delete_message( - self, room: U[nio.MatrixRoom, str], message_id: U[nio.RoomMessage, str], reason: str = None + self, room: U[nio.MatrixRoom, str], message_id: U[nio.RoomMessage, str], reason: typing.Optional[str] = None ) -> nio.RoomRedactResponse: """ Delete an existing message. You must be the sender of the message. @@ -850,7 +856,12 @@ async def redact_reaction(self, room: U[nio.MatrixRoom, str], reaction: U[nio.Ro raise MessageException("Failed to delete reaction.", response) return response - async def start(self, password: str = None, access_token: str = None, sso_token: str = None) -> None: + async def start( + self, + password: typing.Optional[str] = None, + access_token: typing.Optional[str] = None, + sso_token: typing.Optional[str] = None, + ) -> None: """Starts the bot, running the sync loop.""" self.loop = asyncio.get_event_loop() if password or sso_token: @@ -864,10 +875,10 @@ async def start(self, password: str = None, access_token: str = None, sso_token: login_response = await self.login(password=password, token=sso_token, device_name=self.device_id) if isinstance(login_response, nio.LoginError): raise LoginException("Failed to log in.", login_response) - else: - self.log.info("Logged in as %s", login_response.user_id) - self.log.debug("Logged in: {0.access_token}, {0.user_id}".format(login_response)) - self.start_time = time.time() + + self.log.info("Logged in as %s", login_response.user_id) + self.log.debug("Logged in: {0.access_token}, {0.user_id}".format(login_response)) + self.start_time = time.time() elif access_token: self.log.info("Logging in with existing access token.") if self.store_path: @@ -905,7 +916,13 @@ async def start(self, password: str = None, access_token: str = None, sso_token: self.log.info("Closing http session and logging out.") await self.close() - def run(self, *, password: str = None, access_token: str = None, sso_token: str = None) -> None: + def run( + self, + *, + password: typing.Optional[str] = None, + access_token: typing.Optional[str] = None, + sso_token: typing.Optional[str] = None, + ) -> None: """ Runs the bot, blocking the program until the event loop exists. This should be the last function to be called in your script, as once it exits, the bot will stop running. diff --git a/src/niobot/commands.py b/src/niobot/commands.py index 886ba8c..1e11942 100644 --- a/src/niobot/commands.py +++ b/src/niobot/commands.py @@ -7,6 +7,7 @@ from .context import Context from .exceptions import * +from collections.abc import Callable if typing.TYPE_CHECKING: from .client import NioBot @@ -46,7 +47,7 @@ def __init__( name: str, arg_type: _T, *, - description: str = None, + description: typing.Optional[str] = None, default: typing.Any = ..., required: bool = ..., parser: typing.Callable[["Context", "Argument", str], typing.Optional[_T]] = ..., @@ -146,10 +147,10 @@ def hello(ctx: niobot.Context): def __init__( self, name: str, - callback: callable, + callback: Callable, *, - aliases: list[str] = None, - description: str = None, + aliases: typing.Optional[list[str]] = None, + description: typing.Optional[str] = None, disabled: bool = False, hidden: bool = False, greedy: bool = False, @@ -236,16 +237,15 @@ def display_usage(self) -> str: """Returns the usage string for this command, auto-resolved if not pre-defined""" if self.usage: return self.usage - else: - usage = [] - req = "<{!s}>" - opt = "[{!s}]" - for arg in self.arguments[1:]: - if arg.required: - usage.append(req.format(arg.name)) - else: - usage.append(opt.format(arg.name)) - return " ".join(usage) + usage = [] + req = "<{!s}>" + opt = "[{!s}]" + for arg in self.arguments[1:]: + if arg.required: + usage.append(req.format(arg.name)) + else: + usage.append(opt.format(arg.name)) + return " ".join(usage) async def invoke(self, ctx: Context) -> typing.Coroutine: """ @@ -278,9 +278,8 @@ async def invoke(self, ctx: Context) -> typing.Coroutine: if index >= len(ctx.args): if argument.required: raise CommandArgumentsError(f"Missing required argument {argument.name}") - else: - parsed_args.append(argument.default) - continue + parsed_args.append(argument.default) + continue self.log.debug("Resolved argument %s to %r", argument.name, ctx.args[index]) try: @@ -332,7 +331,7 @@ def construct_context( return cls(client, room, src_event, self, invoking_prefix=invoking_prefix, invoking_string=meta) -def command(name: str = None, **kwargs) -> callable: +def command(name: typing.Optional[str] = None, **kwargs) -> Callable: """ Allows you to register commands later on, by loading modules. @@ -356,8 +355,8 @@ def decorator(func): def check( function: typing.Callable[[Context], typing.Union[bool, typing.Coroutine[None, None, bool]]], - name: str = None, -) -> callable: + name: typing.Optional[str] = None, +) -> Callable: """ Allows you to register checks in modules. @@ -384,7 +383,7 @@ def decorator(command_function): return decorator -def event(name: str) -> callable: +def event(name: str) -> Callable: """ Allows you to register event listeners in modules. diff --git a/src/niobot/context.py b/src/niobot/context.py index 148957b..b22ff41 100644 --- a/src/niobot/context.py +++ b/src/niobot/context.py @@ -63,7 +63,7 @@ async def edit(self, content: str, **kwargs) -> "ContextualResponse": await self.ctx.client.edit_message(self.ctx.room, self._response.event_id, content, **kwargs) return self - async def delete(self, reason: str = None) -> None: + async def delete(self, reason: typing.Optional[str] = None) -> None: """ Redacts the current response. @@ -84,7 +84,7 @@ def __init__( command: "Command", *, invoking_prefix: typing.Optional[str] = None, - invoking_string: str = None, + invoking_string: typing.Optional[str] = None, ): self._init_ts = time.time() self._client = _client @@ -152,7 +152,9 @@ def latency(self) -> float: """Returns the current event's latency in milliseconds.""" return self.client.latency(self.event, received_at=self._init_ts) - async def respond(self, content: str = None, file: "BaseAttachment" = None) -> ContextualResponse: + async def respond( + self, content: typing.Optional[str] = None, file: typing.Optional["BaseAttachment"] = None + ) -> ContextualResponse: """ Responds to the current event. diff --git a/src/niobot/exceptions.py b/src/niobot/exceptions.py index 41380cb..ceab4c9 100644 --- a/src/niobot/exceptions.py +++ b/src/niobot/exceptions.py @@ -47,24 +47,24 @@ class NioBotException(Exception): def __init__( self, - message: str = None, - response: nio.ErrorResponse = None, + message: typing.Optional[str] = None, + response: typing.Optional[nio.ErrorResponse] = None, *, - exception: BaseException = None, - original: typing.Union[nio.ErrorResponse, BaseException] = None, + exception: typing.Optional[BaseException] = None, + original: typing.Optional[typing.Union[nio.ErrorResponse, BaseException]] = None, ): if original: warnings.warn(DeprecationWarning("original is deprecated, use response or exception instead")) self.original = original or response or exception self.response = response - self.exception: typing.Union[nio.ErrorResponse, BaseException] = exception + self.exception: typing.Optional[typing.Union[nio.ErrorResponse, BaseException]] = exception self.message = message if self.original is None and self.message is None: raise ValueError("If there is no error history, at least a human readable message should be provided.") def bottom_of_chain( - self, other: typing.Union[Exception, nio.ErrorResponse] = None + self, other: typing.Optional[typing.Union[Exception, nio.ErrorResponse]] = None ) -> typing.Union[BaseException, nio.ErrorResponse]: """Recursively checks the `original` attribute of the exception until it reaches the bottom of the chain. @@ -200,9 +200,9 @@ class CheckFailure(CommandPreparationError): def __init__( self, - check_name: str = None, - message: str = None, - exception: BaseException = None, + check_name: typing.Optional[str] = None, + message: typing.Optional[str] = None, + exception: typing.Optional[BaseException] = None, ): if not message: message = f"Check {check_name} failed." @@ -224,7 +224,12 @@ class NotOwner(CheckFailure): Exception raised when the command invoker is not the owner of the bot. """ - def __init__(self, check_name: str = None, message: str = None, exception: BaseException = None): + def __init__( + self, + check_name: typing.Optional[str] = None, + message: typing.Optional[str] = None, + exception: typing.Optional[BaseException] = None, + ): if not message: message = "You are not the owner of this bot." super().__init__(check_name, message, exception) @@ -236,7 +241,13 @@ class InsufficientPower(CheckFailure): """ def __init__( - self, check_name: str = None, message: str = None, exception: BaseException = None, *, needed: int, have: int + self, + check_name: typing.Optional[str] = None, + message: typing.Optional[str] = None, + exception: typing.Optional[BaseException] = None, + *, + needed: int, + have: int, ): if not message: message = "Insufficient power level. Needed %d, have %d." % (needed, have) @@ -248,7 +259,12 @@ class NotADirectRoom(CheckFailure): Exception raised when the current room is not `m.direct` (a DM room) """ - def __init__(self, check_name: str = None, message: str = None, exception: BaseException = None): + def __init__( + self, + check_name: typing.Optional[str] = None, + message: typing.Optional[str] = None, + exception: typing.Optional[BaseException] = None, + ): if not message: message = "This command can only be run in a direct message room." super().__init__(check_name, message, exception) diff --git a/src/niobot/utils/checks.py b/src/niobot/utils/checks.py index 27cb01c..b208d35 100644 --- a/src/niobot/utils/checks.py +++ b/src/niobot/utils/checks.py @@ -1,6 +1,7 @@ from ..commands import check from ..context import Context from ..exceptions import CheckFailure, InsufficientPower, NotOwner +from typing import Optional __all__ = ( "is_owner", @@ -10,7 +11,7 @@ ) -def is_owner(*extra_owner_ids, name: str = None): +def is_owner(*extra_owner_ids, name: Optional[str] = None): """ Requires the sender owns the bot ([`NioBot.owner_id`][]), or is in `extra_owner_ids`. :param extra_owner_ids: A set of `@localpart:homeserver.tld` strings to check against. @@ -29,7 +30,7 @@ def predicate(ctx): return check(predicate, name) -def is_dm(allow_dual_membership: bool = False, name: str = None): +def is_dm(allow_dual_membership: bool = False, name: Optional[str] = None): """ Requires that the current room is a DM with the sender. @@ -50,7 +51,7 @@ def predicate(ctx: "Context"): return check(predicate, name) -def sender_has_power(level: int, room_creator_bypass: bool = False, name: str = None): +def sender_has_power(level: int, room_creator_bypass: bool = False, name: Optional[str] = None): """ Requires that the sender has a certain power level in the current room before running the command. @@ -70,7 +71,7 @@ def predicate(ctx): return check(predicate, name) -def client_has_power(level: int, name: str = None): +def client_has_power(level: int, name: Optional[str] = None): """ Requires that the bot has a certain power level in the current room before running the command. diff --git a/src/niobot/utils/help_command.py b/src/niobot/utils/help_command.py index d58dc14..90ec69c 100644 --- a/src/niobot/utils/help_command.py +++ b/src/niobot/utils/help_command.py @@ -27,7 +27,7 @@ def clean_output( escape_room_references: bool = False, escape_all_periods: bool = False, escape_all_at_signs: bool = False, - escape_method: typing.Callable[[str], str] = None, + escape_method: typing.Optional[typing.Callable[[str], str]] = None, ) -> str: """ Escapes given text and sanitises it, ready for outputting to the user. @@ -51,9 +51,11 @@ def clean_output( """ if escape_method is None: - def escape_method(x: str) -> str: + def default_escape_method(x: str) -> str: return "\u200b".join(x.split()) + escape_method = default_escape_method + if escape_user_mentions: text = re.sub(r"@([A-Za-z0-9\-_=+./]+):([A-Za-z0-9\-_=+./]+)", escape_method("@\\1:\\2"), text) if escape_room_mentions: @@ -79,7 +81,7 @@ def format_command_name(command: "Command") -> str: def format_command_line(prefix: str, command: "Command") -> str: """Formats a command line, including name(s) & usage.""" name = format_command_name(command) - start = "{}{}".format(prefix, name) + start = f"{prefix}{name}" start += " " + command.display_usage.strip().replace("\n", "") return start diff --git a/src/niobot/utils/parsers.py b/src/niobot/utils/parsers.py index 4b9ab5f..984d32d 100644 --- a/src/niobot/utils/parsers.py +++ b/src/niobot/utils/parsers.py @@ -40,7 +40,7 @@ def boolean_parser(_: "Context", __, value: str) -> bool: """ - Converts a given string into a boolean. Value is lower-cased before being parsed. + Converts a given string into a boolean. Value is casefolded before being parsed. The following resolves to true: * 1, y, yes, true, on @@ -52,10 +52,10 @@ def boolean_parser(_: "Context", __, value: str) -> bool: :return: The parsed boolean """ - value = value.lower() - if value in ("1", "y", "yes", "true", "on"): + value = value.casefold() + if value in {"1", "y", "yes", "true", "on"}: return True - if value in ("0", "n", "no", "false", "off"): + if value in {"0", "n", "no", "false", "off"}: return False raise CommandParserError(f"Invalid boolean value: {value}. Should be a sensible value, such as 1, yes, false.") @@ -99,9 +99,7 @@ def __parser(_, __, v) -> typing.Union[int, float]: return __parser -def json_parser( - _: "Context", __: "Argument", value: str -) -> typing.Union[list, dict, str, int, float, type(None), bool]: +def json_parser(_: "Context", __: "Argument", value: str) -> typing.Union[list, dict, str, int, float, None, bool]: """ Converts a given string into a JSON object. @@ -161,7 +159,7 @@ async def room_parser(ctx: "Context", arg: "Argument", value: str) -> nio.Matrix raise CommandParserError(f"Invalid room ID, alias, or matrix.to link: {value!r}.") if room is None: - raise CommandParserError(f"No room with that ID, alias, or matrix.to link found.") + raise CommandParserError("No room with that ID, alias, or matrix.to link found.") return room @@ -222,42 +220,42 @@ def matrix_to_parser( """ async def internal(ctx: "Context", _, value: str) -> MatrixToLink: - if m := MATRIX_TO_REGEX.match(value): - # matrix.to link - groups = m.groupdict() - event_id = groups.get("event_id", "") - room_id = groups.get("room_id", "") - event_id = urllib.unquote(event_id) - room_id = urllib.unquote(room_id) + if not (m := MATRIX_TO_REGEX.match(value)): + raise CommandParserError(f"Invalid matrix.to link: {value!r}.") - if require_room and not room_id: - raise CommandParserError(f"Invalid matrix.to link: {value} (no room).") - if require_event and not event_id: - raise CommandParserError(f"Invalid matrix.to link: {value} (no event).") + # matrix.to link + groups = m.groupdict() + event_id = groups.get("event_id", "") + room_id = groups.get("room_id", "") + event_id = urllib.unquote(event_id) + room_id = urllib.unquote(room_id) - if room_id.startswith("@") and not allow_user_as_room: - raise CommandParserError(f"Invalid matrix.to link: {value} (expected room, got user).") + if require_room and not room_id: + raise CommandParserError(f"Invalid matrix.to link: {value} (no room).") + if require_event and not event_id: + raise CommandParserError(f"Invalid matrix.to link: {value} (no event).") + + if room_id.startswith("@") and not allow_user_as_room: + raise CommandParserError(f"Invalid matrix.to link: {value} (expected room, got user).") - if room_id.startswith("@"): - room = await ctx.client.get_dm_room(room_id) - else: - room = ctx.client.rooms.get(room_id) + if room_id.startswith("@"): + room = await ctx.client.get_dm_room(room_id) + else: + room = ctx.client.rooms.get(room_id) - if room is None: - raise CommandParserError(f"No room with that ID, alias, or matrix.to link found.") + if room is None: + raise CommandParserError("No room with that ID, alias, or matrix.to link found.") - if event_id: - event: U[nio.RoomGetEventResponse, nio.RoomGetEventError] = await ctx.client.room_get_event( - room_id, event_id - ) - if not isinstance(event, nio.RoomGetEventResponse): - raise CommandParserError(f"Invalid event ID: {event_id}.", response=event) - event: nio.Event = event.event - else: - event: None = None - return MatrixToLink(room, event, groups.get("qs")) + if event_id: + event: U[nio.RoomGetEventResponse, nio.RoomGetEventError] = await ctx.client.room_get_event( + room_id, event_id + ) + if not isinstance(event, nio.RoomGetEventResponse): + raise CommandParserError(f"Invalid event ID: {event_id}.", response=event) + event: nio.Event = event.event else: - raise CommandParserError(f"Invalid matrix.to link: {value!r}.") + event: None = None + return MatrixToLink(room, event, groups.get("qs")) return internal diff --git a/src/niobot/utils/string_view.py b/src/niobot/utils/string_view.py index c4b7730..667f9db 100644 --- a/src/niobot/utils/string_view.py +++ b/src/niobot/utils/string_view.py @@ -24,7 +24,7 @@ def __init__(self, string: str): self.source = string self.index = 0 - self.arguments = [] + self.arguments: list[str] = [] def add_arg(self, argument: str) -> None: """Adds an argument to the argument list @@ -65,17 +65,14 @@ def parse_arguments(self) -> "ArgumentView": self.add_arg(reconstructed) reconstructed = "" quote_char = None + elif self.index == 0: # cannot be an escaped string + quote_started = True + quote_char = char + elif self.index > 0 and self.source[self.index - 1] != "\\": + quote_started = True + quote_char = char else: - if self.index == 0: # cannot be an escaped string - quote_started = True - quote_char = char - elif self.index > 0 and self.source[self.index - 1] != "\\": - quote_started = True - quote_char = char - # If it is an escaped quote, we can add it to the string. - else: - reconstructed += char - # If the character is a space, we can add the reconstructed string to the arguments list + reconstructed += char elif char.isspace(): if quote_started: reconstructed += char @@ -83,7 +80,6 @@ def parse_arguments(self) -> "ArgumentView": self.add_arg(reconstructed) reconstructed = "" quote_char = None - # Any other character can be added to the current string elif char: # elif ensures the character isn't null reconstructed += char self.index += 1 diff --git a/src/niobot/utils/typing.py b/src/niobot/utils/typing.py index 97a7616..c9905e6 100644 --- a/src/niobot/utils/typing.py +++ b/src/niobot/utils/typing.py @@ -9,7 +9,7 @@ __all__ = ("Typing",) log = logging.getLogger(__name__) -_TYPING_STATES: typing.Dict[str, "Typing"] = {} +_TYPING_STATES: dict[str, "Typing"] = {} class Typing: diff --git a/src/niobot/utils/unblocking.py b/src/niobot/utils/unblocking.py index d5c8f8c..4adf0e8 100644 --- a/src/niobot/utils/unblocking.py +++ b/src/niobot/utils/unblocking.py @@ -2,11 +2,14 @@ import functools import typing from typing import Any +from collections.abc import Callable __all__ = ("run_blocking", "force_await") +T = typing.TypeVar("T") -async def run_blocking(function: typing.Callable, *args: Any, **kwargs: Any) -> Any: + +async def run_blocking(function: Callable[..., T], *args: Any, **kwargs: Any) -> T: """ Takes a blocking function and runs it in a thread, returning the result.