diff --git a/asyncssh/connection.py b/asyncssh/connection.py index 0faf4ea..dafeb91 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -899,6 +899,8 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._can_send_ext_info = False self._extensions_to_send: 'OrderedDict[bytes, bytes]' = OrderedDict() + self._can_recv_ext_info = False + self._server_sig_algs: Set[bytes] = set() self._next_service: Optional[bytes] = None @@ -908,6 +910,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._auth: Optional[Auth] = None self._auth_in_progress = False self._auth_complete = False + self._auth_final = False self._auth_methods = [b'none'] self._auth_was_trivial = True self._username = '' @@ -1538,15 +1541,25 @@ def _recv_packet(self) -> bool: skip_reason = '' exc_reason = '' - if self._kex and MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST: - if self._ignore_first_kex: # pragma: no cover - skip_reason = 'ignored first kex' - self._ignore_first_kex = False + if MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST: + if self._kex: + if self._ignore_first_kex: # pragma: no cover + skip_reason = 'ignored first kex' + self._ignore_first_kex = False + else: + handler = self._kex else: - handler = self._kex - elif (self._auth and - MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST): - handler = self._auth + skip_reason = 'kex not in progress' + exc_reason = 'Key exchange not in progress' + elif MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST: + if self._auth: + handler = self._auth + else: + skip_reason = 'auth not in progress' + exc_reason = 'Authentication not in progress' + elif pkttype > MSG_KEX_LAST and not self._recv_encryption: + skip_reason = 'invalid request before kex complete' + exc_reason = 'Invalid request before key exchange was complete' elif pkttype > MSG_USERAUTH_LAST and not self._auth_complete: skip_reason = 'invalid request before auth complete' exc_reason = 'Invalid request before authentication was complete' @@ -1579,6 +1592,9 @@ def _recv_packet(self) -> bool: if exc_reason: raise ProtocolError(exc_reason) + if pkttype > MSG_USERAUTH_LAST: + self._auth_final = True + if self._transport: self._recv_seq = (seq + 1) & 0xffffffff self._recv_handler = self._recv_pkthdr @@ -1596,9 +1612,7 @@ def send_packet(self, pkttype: int, *args: bytes, self._send_kexinit() self._kexinit_sent = True - if (((pkttype in {MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT} or - pkttype > MSG_KEX_LAST) and not self._kex_complete) or - (pkttype == MSG_USERAUTH_BANNER and + if ((pkttype == MSG_USERAUTH_BANNER and not (self._auth_in_progress or self._auth_complete)) or (pkttype > MSG_USERAUTH_LAST and not self._auth_complete)): self._deferred_packets.append((pkttype, args)) @@ -1810,9 +1824,11 @@ def send_newkeys(self, k: bytes, h: bytes) -> None: not self._waiter.cancelled(): self._waiter.set_result(None) self._wait = None - else: - self.send_service_request(_USERAUTH_SERVICE) + return else: + self._extensions_to_send[b'server-sig-algs'] = \ + b','.join(self._sig_algs) + self._send_encryption = next_enc_sc self._send_enchdrlen = 1 if etm_sc else 5 self._send_blocksize = max(8, enc_blocksize_sc) @@ -1833,17 +1849,18 @@ def send_newkeys(self, k: bytes, h: bytes) -> None: recv_mac=self._mac_alg_cs.decode('ascii'), recv_compression=self._cmp_alg_cs.decode('ascii')) - if first_kex: - self._next_service = _USERAUTH_SERVICE - - self._extensions_to_send[b'server-sig-algs'] = \ - b','.join(self._sig_algs) - if self._can_send_ext_info: self._send_ext_info() self._can_send_ext_info = False self._kex_complete = True + + if first_kex: + if self.is_client(): + self.send_service_request(_USERAUTH_SERVICE) + else: + self._next_service = _USERAUTH_SERVICE + self._send_deferred_packets() def send_service_request(self, service: bytes) -> None: @@ -2080,18 +2097,25 @@ def _process_service_request(self, _pkttype: int, _pktid: int, service = packet.get_string() packet.check_end() - if service == self._next_service: - self.logger.debug2('Accepting request for service %s', service) + if self.is_client(): + raise ProtocolError('Unexpected service request received') - self.send_packet(MSG_SERVICE_ACCEPT, String(service)) + if not self._recv_encryption: + raise ProtocolError('Service request received before kex complete') - if (self.is_server() and # pragma: no branch - not self._auth_in_progress and - service == _USERAUTH_SERVICE): - self._auth_in_progress = True - self._send_deferred_packets() - else: - raise ServiceNotAvailable('Unexpected service request received') + if service != self._next_service: + raise ServiceNotAvailable('Unexpected service in service request') + + self.logger.debug2('Accepting request for service %s', service) + + self.send_packet(MSG_SERVICE_ACCEPT, String(service)) + + self._next_service = None + + if service == _USERAUTH_SERVICE: # pragma: no branch + self._auth_in_progress = True + self._can_recv_ext_info = False + self._send_deferred_packets() def _process_service_accept(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: @@ -2100,27 +2124,35 @@ def _process_service_accept(self, _pkttype: int, _pktid: int, service = packet.get_string() packet.check_end() - if service == self._next_service: - self.logger.debug2('Request for service %s accepted', service) + if self.is_server(): + raise ProtocolError('Unexpected service accept received') - self._next_service = None + if not self._recv_encryption: + raise ProtocolError('Service accept received before kex complete') - if (self.is_client() and # pragma: no branch - service == _USERAUTH_SERVICE): - self.logger.info('Beginning auth for user %s', self._username) + if service != self._next_service: + raise ServiceNotAvailable('Unexpected service in service accept') - self._auth_in_progress = True + self.logger.debug2('Request for service %s accepted', service) - # This method is only in SSHClientConnection - # pylint: disable=no-member - cast('SSHClientConnection', self).try_next_auth() - else: - raise ServiceNotAvailable('Unexpected service accept received') + self._next_service = None + + if service == _USERAUTH_SERVICE: # pragma: no branch + self.logger.info('Beginning auth for user %s', self._username) + + self._auth_in_progress = True + + # This method is only in SSHClientConnection + # pylint: disable=no-member + cast('SSHClientConnection', self).try_next_auth() def _process_ext_info(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process extension information""" + if not self._can_recv_ext_info: + raise ProtocolError('Unexpected ext_info received') + extensions: Dict[bytes, bytes] = {} self.logger.debug2('Received extension info') @@ -2246,6 +2278,7 @@ def _process_newkeys(self, _pkttype: int, _pktid: int, self._decompress_after_auth = self._next_decompress_after_auth self._next_recv_encryption = None + self._can_recv_ext_info = True else: raise ProtocolError('New keys not negotiated') @@ -2273,8 +2306,10 @@ def _process_userauth_request(self, _pkttype: int, _pktid: int, if self.is_client(): raise ProtocolError('Unexpected userauth request') elif self._auth_complete: - # Silently ignore requests if we're already authenticated - pass + # Silently ignore additional auth requests after auth succeeds, + # until the client sends a non-auth message + if self._auth_final: + raise ProtocolError('Unexpected userauth request') else: if username != self._username: self.logger.info('Beginning auth for user %s', username) @@ -2316,7 +2351,7 @@ async def _finish_userauth(self, begin_auth: bool, method: bytes, self._auth = lookup_server_auth(cast(SSHServerConnection, self), self._username, method, packet) - def _process_userauth_failure(self, _pkttype: int, pktid: int, + def _process_userauth_failure(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a user authentication failure response""" @@ -2356,10 +2391,9 @@ def _process_userauth_failure(self, _pkttype: int, pktid: int, # pylint: disable=no-member cast(SSHClientConnection, self).try_next_auth() else: - self.logger.debug2('Unexpected userauth failure response') - self.send_packet(MSG_UNIMPLEMENTED, UInt32(pktid)) + raise ProtocolError('Unexpected userauth failure response') - def _process_userauth_success(self, _pkttype: int, pktid: int, + def _process_userauth_success(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a user authentication success response""" @@ -2385,6 +2419,7 @@ def _process_userauth_success(self, _pkttype: int, pktid: int, self._auth = None self._auth_in_progress = False self._auth_complete = True + self._can_recv_ext_info = False if self._agent: self._agent.close() @@ -2412,8 +2447,7 @@ def _process_userauth_success(self, _pkttype: int, pktid: int, self._waiter.set_result(None) self._wait = None else: - self.logger.debug2('Unexpected userauth success response') - self.send_packet(MSG_UNIMPLEMENTED, UInt32(pktid)) + raise ProtocolError('Unexpected userauth success response') def _process_userauth_banner(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: diff --git a/tests/test_connection.py b/tests/test_connection.py index f1bf8c5..65b1542 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -30,11 +30,12 @@ from unittest.mock import patch import asyncssh -from asyncssh.constants import MSG_UNIMPLEMENTED, MSG_DEBUG +from asyncssh.constants import MSG_DEBUG from asyncssh.constants import MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT -from asyncssh.constants import MSG_KEXINIT, MSG_NEWKEYS +from asyncssh.constants import MSG_KEXINIT, MSG_NEWKEYS, MSG_KEX_FIRST from asyncssh.constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS from asyncssh.constants import MSG_USERAUTH_FAILURE, MSG_USERAUTH_BANNER +from asyncssh.constants import MSG_USERAUTH_FIRST from asyncssh.constants import MSG_GLOBAL_REQUEST from asyncssh.constants import MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_CONFIRMATION from asyncssh.constants import MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_DATA @@ -337,14 +338,6 @@ def begin_auth(self, username): return False -def disconnect_on_unimplemented(self, pkttype, pktid, packet): - """Process an unimplemented message response""" - - # pylint: disable=unused-argument - - self.disconnect(asyncssh.DISC_BY_APPLICATION, 'Unexpected response') - - @patch_gss @patch('asyncssh.connection.SSHClientConnection', _CheckAlgsClientConnection) class _TestConnection(ServerTestCase): @@ -974,8 +967,8 @@ def send_newkeys(self, k, h): with patch('asyncssh.connection.SSHClientConnection.send_newkeys', send_newkeys): - async with self.connect(): - pass + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() @asynctest async def test_encryption_algs(self): @@ -1101,21 +1094,85 @@ async def test_invalid_debug(self): await conn.wait_closed() @asynctest - async def test_invalid_service_request(self): - """Test invalid service request""" + async def test_service_request_before_kex_complete(self): + """Test service request before kex is complete""" + + def send_newkeys(self, k, h): + """Finish a key exchange and send a new keys message""" + + self.send_packet(MSG_SERVICE_REQUEST, String('ssh-userauth')) + + asyncssh.connection.SSHConnection.send_newkeys(self, k, h) + + with patch('asyncssh.connection.SSHClientConnection.send_newkeys', + send_newkeys): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + + @asynctest + async def test_service_accept_before_kex_complete(self): + """Test service accept before kex is complete""" + + def send_newkeys(self, k, h): + """Finish a key exchange and send a new keys message""" + + self.send_packet(MSG_SERVICE_ACCEPT, String('ssh-userauth')) + + asyncssh.connection.SSHConnection.send_newkeys(self, k, h) + + with patch('asyncssh.connection.SSHServerConnection.send_newkeys', + send_newkeys): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + + @asynctest + async def test_unexpected_service_name_in_request(self): + """Test unexpected service name in service request""" conn = await self.connect() conn.send_packet(MSG_SERVICE_REQUEST, String('xxx')) await conn.wait_closed() @asynctest - async def test_invalid_service_accept(self): - """Test invalid service accept""" + async def test_unexpected_service_name_in_accept(self): + """Test unexpected service name in accept sent by server""" + + def send_newkeys(self, k, h): + """Finish a key exchange and send a new keys message""" + + asyncssh.connection.SSHConnection.send_newkeys(self, k, h) + + self.send_packet(MSG_SERVICE_ACCEPT, String('xxx')) + + with patch('asyncssh.connection.SSHServerConnection.send_newkeys', + send_newkeys): + with self.assertRaises(asyncssh.ServiceNotAvailable): + await self.connect() + + @asynctest + async def test_service_accept_from_client(self): + """Test service accept sent by client""" conn = await self.connect() - conn.send_packet(MSG_SERVICE_ACCEPT, String('xxx')) + conn.send_packet(MSG_SERVICE_ACCEPT, String('ssh-userauth')) await conn.wait_closed() + @asynctest + async def test_service_request_from_server(self): + """Test service request sent by server""" + + def send_newkeys(self, k, h): + """Finish a key exchange and send a new keys message""" + + asyncssh.connection.SSHConnection.send_newkeys(self, k, h) + + self.send_packet(MSG_SERVICE_REQUEST, String('ssh-userauth')) + + with patch('asyncssh.connection.SSHServerConnection.send_newkeys', + send_newkeys): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + @asynctest async def test_packet_decode_error(self): """Test SSH packet decode error""" @@ -1322,6 +1379,39 @@ async def test_invalid_newkeys(self): conn.send_packet(MSG_NEWKEYS) await conn.wait_closed() + @asynctest + async def test_kex_after_kex_complete(self): + """Test kex request when kex not in progress""" + + conn = await self.connect() + conn.send_packet(MSG_KEX_FIRST) + await conn.wait_closed() + + @asynctest + async def test_userauth_after_auth_complete(self): + """Test userauth request when auth not in progress""" + + conn = await self.connect() + conn.send_packet(MSG_USERAUTH_FIRST) + await conn.wait_closed() + + @asynctest + async def test_userauth_before_kex_complete(self): + """Test receiving userauth before kex is complete""" + + def send_newkeys(self, k, h): + """Finish a key exchange and send a new keys message""" + + self.send_packet(MSG_USERAUTH_REQUEST, String('guest'), + String('ssh-connection'), String('none')) + + asyncssh.connection.SSHConnection.send_newkeys(self, k, h) + + with patch('asyncssh.connection.SSHClientConnection.send_newkeys', + send_newkeys): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + @asynctest async def test_invalid_userauth_service(self): """Test invalid service in userauth request""" @@ -1371,25 +1461,32 @@ async def test_extra_userauth_request(self): String('ssh-connection'), String('none')) await asyncio.sleep(0.1) + @asynctest + async def test_late_userauth_request(self): + """Test userauth request after auth is final""" + + async with self.connect() as conn: + conn.send_packet(MSG_GLOBAL_REQUEST, String('xxx'), + Boolean(False)) + conn.send_packet(MSG_USERAUTH_REQUEST, String('guest'), + String('ssh-connection'), String('none')) + await conn.wait_closed() + @asynctest async def test_unexpected_userauth_success(self): """Test unexpected userauth success response""" - with patch.dict('asyncssh.connection.SSHConnection._packet_handlers', - {MSG_UNIMPLEMENTED: disconnect_on_unimplemented}): - conn = await self.connect() - conn.send_packet(MSG_USERAUTH_SUCCESS) - await conn.wait_closed() + conn = await self.connect() + conn.send_packet(MSG_USERAUTH_SUCCESS) + await conn.wait_closed() @asynctest async def test_unexpected_userauth_failure(self): """Test unexpected userauth failure response""" - with patch.dict('asyncssh.connection.SSHConnection._packet_handlers', - {MSG_UNIMPLEMENTED: disconnect_on_unimplemented}): - conn = await self.connect() - conn.send_packet(MSG_USERAUTH_FAILURE, NameList([]), Boolean(False)) - await conn.wait_closed() + conn = await self.connect() + conn.send_packet(MSG_USERAUTH_FAILURE, NameList([]), Boolean(False)) + await conn.wait_closed() @asynctest async def test_unexpected_userauth_banner(self):