diff --git a/discord/gateway.py b/discord/gateway.py index 92fa4f56c913..b035448932db 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -719,6 +719,7 @@ def __init__(self, socket, loop): self.loop = loop self._keep_alive = None self._close_code = None + self.secret_key = None async def send_as_json(self, data): log.debug('Sending voice websocket frame: %s.', data) @@ -872,7 +873,7 @@ def average_latency(self): async def load_secret_key(self, data): log.info('received secret key for voice connection') - self._connection.secret_key = data.get('secret_key') + self.secret_key = self._connection.secret_key = data.get('secret_key') await self.speak() await self.speak(False) diff --git a/discord/voice_client.py b/discord/voice_client.py index a16aaf41dd0e..52c881098291 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -208,6 +208,7 @@ def __init__(self, client, channel): self._connected = threading.Event() self._handshaking = False + self._potentially_reconnecting = False self._voice_state_complete = asyncio.Event() self._voice_server_complete = asyncio.Event() @@ -250,8 +251,10 @@ async def on_voice_state_update(self, data): self.session_id = data['session_id'] channel_id = data['channel_id'] - if not self._handshaking: + if not self._handshaking or self._potentially_reconnecting: # If we're done handshaking then we just need to update ourselves + # If we're potentially reconnecting due to a 4014, then we need to differentiate + # a channel move and an actual force disconnect if channel_id is None: # We're being disconnected so cleanup await self.disconnect() @@ -294,26 +297,39 @@ async def on_voice_server_update(self, data): self._voice_server_complete.set() async def voice_connect(self): - self._connections += 1 await self.channel.guild.change_voice_state(channel=self.channel) async def voice_disconnect(self): log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id) await self.channel.guild.change_voice_state(channel=None) + def prepare_handshake(self): + self._voice_state_complete.clear() + self._voice_server_complete.clear() + self._handshaking = True + log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) + self._connections += 1 + + def finish_handshake(self): + log.info('Voice handshake complete. Endpoint found %s', self.endpoint) + self._handshaking = False + self._voice_server_complete.clear() + self._voice_state_complete.clear() + + async def connect_websocket(self): + ws = await DiscordVoiceWebSocket.from_client(self) + self._connected.clear() + while ws.secret_key is None: + await ws.poll_event() + self._connected.set() + return ws + async def connect(self, *, reconnect, timeout): log.info('Connecting to voice...') self.timeout = timeout - try: - del self.secret_key - except AttributeError: - pass - for i in range(5): - self._voice_state_complete.clear() - self._voice_server_complete.clear() - self._handshaking = True + self.prepare_handshake() # This has to be created before we start the flow. futures = [ @@ -322,7 +338,6 @@ async def connect(self, *, reconnect, timeout): ] # Start the connection flow - log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) await self.voice_connect() try: @@ -331,17 +346,10 @@ async def connect(self, *, reconnect, timeout): await self.disconnect(force=True) raise - log.info('Voice handshake complete. Endpoint found %s', self.endpoint) - self._handshaking = False - self._voice_server_complete.clear() - self._voice_state_complete.clear() + self.finish_handshake() try: - self.ws = await DiscordVoiceWebSocket.from_client(self) - self._connected.clear() - while not hasattr(self, 'secret_key'): - await self.ws.poll_event() - self._connected.set() + self.ws = await self.connect_websocket() break except (ConnectionClosed, asyncio.TimeoutError): if reconnect: @@ -355,6 +363,26 @@ async def connect(self, *, reconnect, timeout): if self._runner is None: self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) + async def potential_reconnect(self): + self.prepare_handshake() + self._potentially_reconnecting = True + try: + # We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected + await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout) + except asyncio.TimeoutError: + self._potentially_reconnecting = False + await self.disconnect(force=True) + return False + + self.finish_handshake() + self._potentially_reconnecting = False + try: + self.ws = await self.connect_websocket() + except (ConnectionClosed, asyncio.TimeoutError): + return False + else: + return True + @property def latency(self): """:class:`float`: Latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. @@ -387,10 +415,19 @@ async def poll_voice_ws(self, reconnect): # 1000 - normal closure (obviously) # 4014 - voice channel has been deleted. # 4015 - voice server has crashed - if exc.code in (1000, 4014, 4015): + if exc.code in (1000, 4015): log.info('Disconnecting from voice normally, close code %d.', exc.code) await self.disconnect() break + if exc.code == 4014: + log.info('Disconnected from voice by force... potentially reconnecting.') + successful = await self.potential_reconnect() + if not successful: + log.info('Reconnect was unsuccessful, disconnecting from voice normally...') + await self.disconnect() + break + else: + continue if not reconnect: await self.disconnect()