Skip to content

Commit

Permalink
Ensure writer is always reset on completion (#7815) (#7826)
Browse files Browse the repository at this point in the history
(cherry picked from commit 8f2f048)
  • Loading branch information
Dreamsorcerer authored Nov 12, 2023
1 parent c0f9017 commit cb94533
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGES/7815.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed an issue where the client could go into an infinite loop. -- by :user:`Dreamsorcerer`
74 changes: 49 additions & 25 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@
reify,
set_result,
)
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter
from .http import (
SERVER_SOFTWARE,
HttpVersion,
HttpVersion10,
HttpVersion11,
StreamWriter,
)
from .log import client_logger
from .streams import StreamReader
from .typedefs import (
Expand Down Expand Up @@ -241,7 +247,7 @@ class ClientRequest:
auth = None
response = None

_writer = None # async task for streaming data
__writer = None # async task for streaming data
_continue = None # waiter future for '100 Continue' response

# N.B.
Expand Down Expand Up @@ -332,6 +338,21 @@ def __init__(
traces = []
self._traces = traces

def __reset_writer(self, _: object = None) -> None:
self.__writer = None

@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
return self.__writer

@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
if writer is not None:
writer.add_done_callback(self.__reset_writer)

def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")

Expand Down Expand Up @@ -625,8 +646,6 @@ async def write_bytes(
else:
await writer.write_eof()
protocol.start_timeout()
finally:
self._writer = None

async def send(self, conn: "Connection") -> "ClientResponse":
# Specify request target:
Expand Down Expand Up @@ -711,16 +730,14 @@ async def send(self, conn: "Connection") -> "ClientResponse":

async def close(self) -> None:
if self._writer is not None:
try:
with contextlib.suppress(asyncio.CancelledError):
await self._writer
finally:
self._writer = None
with contextlib.suppress(asyncio.CancelledError):
await self._writer

def terminate(self) -> None:
if self._writer is not None:
if not self.loop.is_closed():
self._writer.cancel()
self._writer.remove_done_callback(self.__reset_writer)
self._writer = None

async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
Expand All @@ -740,9 +757,9 @@ class ClientResponse(HeadersMixin):
# but will be set by the start() method.
# As the end user will likely never see the None values, we cheat the types below.
# from the Status-Line of the response
version = None # HTTP-Version
status: int = None # type: ignore[assignment] # Status-Code
reason = None # Reason-Phrase
version: Optional[HttpVersion] = None # HTTP-Version
status: int = None # type: ignore[assignment] # Status-Code
reason: Optional[str] = None # Reason-Phrase

content: StreamReader = None # type: ignore[assignment] # Payload stream
_headers: CIMultiDictProxy[str] = None # type: ignore[assignment]
Expand All @@ -754,6 +771,7 @@ class ClientResponse(HeadersMixin):
# post-init stage allows to not change ctor signature
_closed = True # to allow __del__ for non-initialized properly response
_released = False
__writer = None

def __init__(
self,
Expand Down Expand Up @@ -799,6 +817,21 @@ def __init__(
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

def __reset_writer(self, _: object = None) -> None:
self.__writer = None

@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
return self.__writer

@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
if writer is not None:
writer.add_done_callback(self.__reset_writer)

@reify
def url(self) -> URL:
return self._url
Expand Down Expand Up @@ -863,7 +896,7 @@ def __repr__(self) -> str:
"ascii", "backslashreplace"
).decode("ascii")
else:
ascii_encodable_reason = self.reason
ascii_encodable_reason = "None"
print(
"<ClientResponse({}) [{} {}]>".format(
ascii_encodable_url, self.status, ascii_encodable_reason
Expand Down Expand Up @@ -1044,18 +1077,12 @@ def _release_connection(self) -> None:

async def _wait_released(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
await self._writer
self._release_connection()

def _cleanup_writer(self) -> None:
if self._writer is not None:
if self._writer.done():
self._writer = None
else:
self._writer.cancel()
self._writer.cancel()
self._session = None

def _notify_content(self) -> None:
Expand All @@ -1066,10 +1093,7 @@ def _notify_content(self) -> None:

async def wait_for_close(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
await self._writer
self.release()

async def read(self) -> bytes:
Expand Down
20 changes: 16 additions & 4 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import urllib.parse
import zlib
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional
from unittest import mock

import pytest
Expand All @@ -24,6 +24,17 @@
from aiohttp.test_utils import make_mocked_coro


class WriterMock(mock.AsyncMock):
def __await__(self) -> None:
return self().__await__()

def add_done_callback(self, cb: Callable[[], None]) -> None:
"""Dummy method."""

def remove_done_callback(self, cb: Callable[[], None]) -> None:
"""Dummy method."""


@pytest.fixture
def make_request(loop):
request = None
Expand Down Expand Up @@ -1167,7 +1178,7 @@ def read(self, decode=False):
async def test_oserror_on_write_bytes(loop, conn) -> None:
req = ClientRequest("POST", URL("http://python.org/"), loop=loop)

writer = mock.Mock()
writer = WriterMock()
writer.write.side_effect = OSError

await req.write_bytes(writer, conn)
Expand All @@ -1183,7 +1194,8 @@ async def test_terminate(loop, conn) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)
resp = await req.send(conn)
assert req._writer is not None
writer = req._writer = mock.Mock()
writer = req._writer = WriterMock()
writer.cancel = mock.Mock()

req.terminate()
assert req._writer is None
Expand All @@ -1201,7 +1213,7 @@ async def go():
req = ClientRequest("get", URL("http://python.org"))
resp = await req.send(conn)
assert req._writer is not None
writer = req._writer = mock.Mock()
writer = req._writer = WriterMock()

await asyncio.sleep(0.05)

Expand Down
4 changes: 4 additions & 0 deletions tests/test_client_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import gc
import sys
from typing import Callable
from unittest import mock

import pytest
Expand All @@ -19,6 +20,9 @@ class WriterMock(mock.AsyncMock):
def __await__(self) -> None:
return self().__await__()

def add_done_callback(self, cb: Callable[[], None]) -> None:
cb()

def done(self) -> bool:
return True

Expand Down
18 changes: 9 additions & 9 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def test_proxy_server_hostname_default(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_proxy_server_hostname_override(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -326,7 +326,7 @@ def test_https_connect(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -386,7 +386,7 @@ def test_https_connect_certificate_error(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -440,7 +440,7 @@ def test_https_connect_ssl_error(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -496,7 +496,7 @@ def test_https_connect_http_proxy_error(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -555,7 +555,7 @@ def test_https_connect_resp_start_error(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -666,7 +666,7 @@ def test_https_connect_pass_ssl_context(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -737,7 +737,7 @@ def test_https_auth(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down

0 comments on commit cb94533

Please sign in to comment.