diff --git a/dev-requirements.txt b/dev-requirements.txt index f225f7a..f86aab0 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -6,6 +6,7 @@ html5lib # needed for beautifulsoup mock notebook pre-commit +psutil pytest-asyncio pytest-cov pytest>=3.3 diff --git a/ldapauthenticator/ldapauthenticator.py b/ldapauthenticator/ldapauthenticator.py index 4e3a809..b18a672 100644 --- a/ldapauthenticator/ldapauthenticator.py +++ b/ldapauthenticator/ldapauthenticator.py @@ -228,6 +228,15 @@ def _server_port_default(self): This can be useful in an heterogeneous environment, when supplying a UNIX username to authenticate against AD. """, ) + secondary_uri = Unicode( + config=True, + default="", + help=""" + Comma separated address:port of the LDAP server which can be tried to contact when + primary LDAP server is unavailable. + + """, + ) def resolve_username(self, username_supplied_by_user): search_dn = self.lookup_dn_search_user @@ -305,8 +314,31 @@ def resolve_username(self, username_supplied_by_user): return (user_dn, response[0]["dn"]) def get_connection(self, userdn, password): + try: + return self._get_real_connection( + userdn, password, self.server_address, self.server_port + ) + except ( + ldap3.core.exceptions.LDAPSocketOpenError, + ldap3.core.exceptions.LDAPBindError, + ldap3.core.exceptions.LDAPSocketReceiveError, + ): + for server, port in self._get_secondary_servers(): + try: + return self._get_real_connection(userdn, password, server, port) + except ( + ldap3.core.exceptions.LDAPSocketOpenError, + ldap3.core.exceptions.LDAPBindError, + ldap3.core.exceptions.LDAPSocketReceiveError, + ): + continue + else: + # re-raise the last caught error + raise + + def _get_real_connection(self, userdn, password, server_address, server_port): server = ldap3.Server( - self.server_address, port=self.server_port, use_ssl=self.use_ssl + server_address, port=server_port, use_ssl=self.use_ssl ) auto_bind = ( ldap3.AUTO_BIND_NO_TLS if self.use_ssl else ldap3.AUTO_BIND_TLS_BEFORE_BIND @@ -316,6 +348,24 @@ def get_connection(self, userdn, password): ) return conn + def _get_secondary_servers(self): + uri_list = self.secondary_uri.split(",") + for uri in uri_list: + server_port = uri.strip().split(":") + assert len(server_port) <= 2 + if len(server_port) == 2: + try: + port = int(server_port[1]) + except ValueError: + self.log.warning( + "Invalid port in secondary uri %s, use default" % uri + ) + port = self._server_port_default() + else: + port = self._server_port_default() + + yield (server_port[0], port) + def get_user_attributes(self, conn, userdn): attrs = {} if self.auth_state_attributes: diff --git a/ldapauthenticator/tests/test_ldapauthenticator.py b/ldapauthenticator/tests/test_ldapauthenticator.py index 6471213..ea2a993 100644 --- a/ldapauthenticator/tests/test_ldapauthenticator.py +++ b/ldapauthenticator/tests/test_ldapauthenticator.py @@ -1,4 +1,14 @@ # Inspired by https://github.com/jupyterhub/jupyterhub/blob/master/jupyterhub/tests/test_auth.py +import random + +import psutil + + +def unused_port(): + while True: + port = random.randint(1024, 65534) + if port not in psutil.net_connections(): + return port async def test_ldap_auth_allowed(authenticator): @@ -100,3 +110,26 @@ async def test_ldap_auth_state_attributes(authenticator): ) assert authorized["name"] == "fry" assert authorized["auth_state"] == {"employeeType": ["Delivery boy"]} + + +async def test_ldap_auth_redirects(authenticator): + # set non-available port + correct_server_port = "%s:%s" % ( + authenticator.server_address, + authenticator._server_port_default(), + ) + authenticator.server_port = unused_port() + + async def _test_ldap_redirect(uri_pattern): + authenticator.secondary_uri = uri_pattern + authorized = await authenticator.get_authenticated_user( + None, {"username": "fry", "password": "fry"} + ) + assert authorized["name"] == "fry" + + await _test_ldap_redirect(correct_server_port) + await _test_ldap_redirect("unavailable,%s" % correct_server_port) + await _test_ldap_redirect("unavailable, %s" % correct_server_port) + await _test_ldap_redirect( + "unavailable:8080,localhost:8080,%s" % correct_server_port + )