From 98d97cc9fdd6d9984372760fdb93d05557396cc5 Mon Sep 17 00:00:00 2001 From: Slava Date: Thu, 22 Jul 2021 22:35:15 +0300 Subject: [PATCH] Keep auth header during http->https redirect (#5848) * Keep auth header during http->https redirect * more informative naming Co-authored-by: Sviatoslav Sydorenko * Update CHANGES/5783.feature Co-authored-by: Sviatoslav Sydorenko * fixes * Update docs/client_advanced.rst Co-authored-by: Sviatoslav Sydorenko * Update tests/test_client_functional.py Co-authored-by: Sviatoslav Sydorenko * Update tests/test_client_functional.py Co-authored-by: Sviatoslav Sydorenko * fixes * more clear naming Co-authored-by: Sviatoslav Sydorenko --- CHANGES/5783.feature | 1 + aiohttp/client.py | 11 +++- docs/client_advanced.rst | 11 ++++ tests/test_client_functional.py | 93 ++++++++++++++++++++++++++++----- 4 files changed, 101 insertions(+), 15 deletions(-) create mode 100644 CHANGES/5783.feature diff --git a/CHANGES/5783.feature b/CHANGES/5783.feature new file mode 100644 index 00000000000..4be16c23343 --- /dev/null +++ b/CHANGES/5783.feature @@ -0,0 +1 @@ +Started keeping the ``Authorization`` header during http->https redirects when the host remains the same. diff --git a/aiohttp/client.py b/aiohttp/client.py index 62b18d07ff6..6d6f1d48496 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -566,7 +566,16 @@ async def _request( elif not scheme: parsed_url = url.join(parsed_url) - if url.origin() != parsed_url.origin(): + is_same_host_https_redirect = ( + url.host == parsed_url.host + and parsed_url.scheme == "https" + and url.scheme == "http" + ) + + if ( + url.origin() != parsed_url.origin() + and not is_same_host_https_redirect + ): auth = None headers.pop(hdrs.AUTHORIZATION, None) diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 40fd2fca728..7a2f4bef217 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -56,6 +56,17 @@ For *text/plain* :: await session.post(url, data='Привет, Мир!') +.. note:: + + ``Authorization`` header will be removed if you get redirected + to a different host or protocol, except the case when ``HTTP -> HTTPS`` + redirect is performed on the same host. + +.. versionchanged:: 4.0 + + Started keeping the ``Authorization`` header during ``HTTP -> HTTPS`` + redirects when the host remains the same. + Custom Cookies -------------- diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 52d74d98324..79e007537cd 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -8,11 +8,13 @@ import json import pathlib import socket +import ssl from typing import Any from unittest import mock import pytest from multidict import MultiDict +from yarl import URL import aiohttp from aiohttp import Fingerprint, ServerFingerprintMismatch, hdrs, web @@ -2333,25 +2335,85 @@ async def test_creds_in_auth_and_url() -> None: await session.close() -async def test_drop_auth_on_redirect_to_other_host(aiohttp_server: Any) -> None: - async def srv1(request): - assert request.host == "host1.com" +@pytest.fixture +def create_server_for_url_and_handler( + aiohttp_server: Any, tls_certificate_authority: Any +): + def create(url: URL, srv: Any): + app = web.Application() + app.router.add_route("GET", url.path, srv) + + kwargs = {} + if url.scheme == "https": + cert = tls_certificate_authority.issue_cert( + url.host, "localhost", "127.0.0.1" + ) + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + cert.configure_cert(ssl_ctx) + kwargs["ssl"] = ssl_ctx + return aiohttp_server(app, **kwargs) + + return create + + +@pytest.mark.parametrize( + ["url_from", "url_to", "is_drop_header_expected"], + [ + [ + "http://host1.com/path1", + "http://host2.com/path2", + True, + ], + ["http://host1.com/path1", "https://host1.com/path1", False], + ["https://host1.com/path1", "http://host1.com/path2", True], + ], + ids=( + "entirely different hosts", + "http -> https", + "https -> http", + ), +) +async def test_drop_auth_on_redirect_to_other_host( + create_server_for_url_and_handler: Any, + url_from: str, + url_to: str, + is_drop_header_expected: bool, +) -> None: + url_from, url_to = URL(url_from), URL(url_to) + + async def srv_from(request): + assert request.host == url_from.host assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz" - raise web.HTTPFound("http://host2.com/path2") + raise web.HTTPFound(url_to) - async def srv2(request): - assert request.host == "host2.com" - assert "Authorization" not in request.headers + async def srv_to(request): + assert request.host == url_to.host + if is_drop_header_expected: + assert "Authorization" not in request.headers, "Header wasn't dropped" + else: + assert "Authorization" in request.headers, "Header was dropped" return web.Response() - app = web.Application() - app.router.add_route("GET", "/path1", srv1) - app.router.add_route("GET", "/path2", srv2) + server_from = await create_server_for_url_and_handler(url_from, srv_from) + server_to = await create_server_for_url_and_handler(url_to, srv_to) - server = await aiohttp_server(app) + assert ( + url_from.host != url_to.host or server_from.scheme != server_to.scheme + ), "Invalid test case, host or scheme must differ" + + protocol_port_map = { + "http": 80, + "https": 443, + } + etc_hosts = { + (url_from.host, protocol_port_map[server_from.scheme]): server_from, + (url_to.host, protocol_port_map[server_to.scheme]): server_to, + } class FakeResolver(AbstractResolver): async def resolve(self, host, port=0, family=socket.AF_INET): + server = etc_hosts[(host, port)] + return [ { "hostname": host, @@ -2366,14 +2428,17 @@ async def resolve(self, host, port=0, family=socket.AF_INET): async def close(self): pass - connector = aiohttp.TCPConnector(resolver=FakeResolver()) + connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False) + async with aiohttp.ClientSession(connector=connector) as client: resp = await client.get( - "http://host1.com/path1", auth=aiohttp.BasicAuth("user", "pass") + url_from, + auth=aiohttp.BasicAuth("user", "pass"), ) assert resp.status == 200 resp = await client.get( - "http://host1.com/path1", headers={"Authorization": "Basic dXNlcjpwYXNz"} + url_from, + headers={"Authorization": "Basic dXNlcjpwYXNz"}, ) assert resp.status == 200