Skip to content

Commit

Permalink
Add an exposed way to extract shard-specific information.
Browse files Browse the repository at this point in the history
Closes #2654
  • Loading branch information
Rapptz committed Jul 25, 2020
1 parent a42bebe commit 7ed26db
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 16 deletions.
2 changes: 1 addition & 1 deletion discord/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from .enums import *
from .embeds import Embed
from .mentions import AllowedMentions
from .shard import AutoShardedClient
from .shard import AutoShardedClient, ShardInfo
from .player import *
from .webhook import *
from .voice_client import VoiceClient
Expand Down
104 changes: 93 additions & 11 deletions discord/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ async def close(self):
self._cancel_task()
await self.ws.close(code=1000)

async def disconnect(self):
await self.close()
self._dispatch('shard_disconnect', self.id)

async def _handle_disconnect(self, e):
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
Expand Down Expand Up @@ -178,6 +182,70 @@ async def reconnect(self):
else:
self.launch()

class ShardInfo:
"""A class that gives information and control over a specific shard.
You can retrieve this object via :meth:`AutoShardedClient.get_shard`
or :attr:`AutoShardedClient.shards`.
.. versionadded:: 1.4
Attributes
------------
id: :class:`int`
The shard ID for this shard.
shard_count: Optional[:class:`int`]
The shard count for this cluster. If this is ``None`` then the bot has not started yet.
"""

__slots__ = ('_parent', 'id', 'shard_count')

def __init__(self, parent, shard_count):
self._parent = parent
self.id = parent.id
self.shard_count = shard_count

def is_closed(self):
""":class:`bool`: Whether the shard connection is currently closed."""
return not self._parent.ws.open

async def disconnect(self):
"""|coro|
Disconnects a shard. When this is called, the shard connection will no
longer be open.
If the shard is already disconnected this does nothing.
"""
if self.is_closed():
return

await self._parent.disconnect()

async def reconnect(self):
"""|coro|
Disconnects and then connects the shard again.
"""
if not self.is_closed():
await self._parent.disconnect()
await self._parent.reconnect()

async def connect(self):
"""|coro|
Connects a shard. If the shard is already connected this does nothing.
"""
if not self.is_closed():
return

await self._parent.reconnect()

@property
def latency(self):
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard."""
return self._parent.ws.latency

class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single
Expand Down Expand Up @@ -221,14 +289,14 @@ def __init__(self, *args, loop=None, **kwargs):

# instead of a single websocket, we have multiple
# the key is the shard_id
self.shards = {}
self.__shards = {}
self._connection._get_websocket = self._get_websocket
self._queue = asyncio.PriorityQueue()
self.__queue = asyncio.PriorityQueue()

def _get_websocket(self, guild_id=None, *, shard_id=None):
if shard_id is None:
shard_id = (guild_id >> 22) % self.shard_count
return self.shards[shard_id].ws
return self.__shards[shard_id].ws

@property
def latency(self):
Expand All @@ -238,17 +306,31 @@ def latency(self):
latency of every shard's latency. To get a list of shard latency, check the
:attr:`latencies` property. Returns ``nan`` if there are no shards ready.
"""
if not self.shards:
if not self.__shards:
return float('nan')
return sum(latency for _, latency in self.latencies) / len(self.shards)
return sum(latency for _, latency in self.latencies) / len(self.__shards)

@property
def latencies(self):
"""List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This returns a list of tuples with elements ``(shard_id, latency)``.
"""
return [(shard_id, shard.ws.latency) for shard_id, shard in self.shards.items()]
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]

def get_shard(self, shard_id):
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
try:
parent = self.__shards[shard_id]
except KeyError:
return None
else:
return ShardInfo(parent, self.shard_count)

@utils.cached_property
def shards(self):
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() }

async def request_offline_members(self, *guilds):
r"""|coro|
Expand Down Expand Up @@ -291,7 +373,7 @@ async def launch_shard(self, gateway, shard_id, *, initial=False):
return await self.launch_shard(gateway, shard_id)

# keep reading the shard while others connect
self.shards[shard_id] = ret = Shard(ws, self)
self.__shards[shard_id] = ret = Shard(ws, self)
ret.launch()

async def launch_shards(self):
Expand All @@ -316,7 +398,7 @@ async def connect(self, *, reconnect=True):
await self.launch_shards()

while not self.is_closed():
item = await self._queue.get()
item = await self.__queue.get()
if item.type == EventType.close:
await self.close()
if isinstance(item.error, ConnectionClosed) and item.error.code != 1000:
Expand Down Expand Up @@ -346,7 +428,7 @@ async def close(self):
except Exception:
pass

to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.shards.values()]
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
if to_close:
await asyncio.wait(to_close)

Expand Down Expand Up @@ -395,12 +477,12 @@ async def change_presence(self, *, activity=None, status=None, afk=False, shard_
status = str(status)

if shard_id is None:
for shard in self.shards.values():
for shard in self.__shards.values():
await shard.ws.change_presence(activity=activity, status=status, afk=afk)

guilds = self._connection.guilds
else:
shard = self.shards[shard_id]
shard = self.__shards[shard_id]
await shard.ws.change_presence(activity=activity, status=status, afk=afk)
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id]

Expand Down
14 changes: 10 additions & 4 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2756,7 +2756,7 @@ Data Classes
Some classes are just there to be data containers, this lists them.

Unlike :ref:`models <discord_api_models>` you are allowed to create
these yourself, even if they can also be used to hold attributes.
most of these yourself, even if they can also be used to hold attributes.

Nearly all classes here have :ref:`py:slots` defined which means that it is
impossible to have dynamic attributes to the data classes.
Expand Down Expand Up @@ -2837,22 +2837,28 @@ PermissionOverwrite
.. autoclass:: PermissionOverwrite
:members:

ShardInfo
~~~~~~~~~~~

.. autoclass:: ShardInfo()
:members:

SystemChannelFlags
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SystemChannelFlags
.. autoclass:: SystemChannelFlags()
:members:

MessageFlags
~~~~~~~~~~~~

.. autoclass:: MessageFlags
.. autoclass:: MessageFlags()
:members:

PublicUserFlags
~~~~~~~~~~~~~~~

.. autoclass:: PublicUserFlags
.. autoclass:: PublicUserFlags()
:members:


Expand Down

0 comments on commit 7ed26db

Please sign in to comment.