Skip to content

Commit

Permalink
Add and use ClientConnectionResetError (#9137)
Browse files Browse the repository at this point in the history
(cherry picked from commit f95bcaf)
  • Loading branch information
Dreamsorcerer committed Sep 18, 2024
1 parent c717b25 commit cf8c986
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 20 deletions.
2 changes: 2 additions & 0 deletions CHANGES/9137.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added :exc:`aiohttp.ClientConnectionResetError`. Client code that previously threw :exc:`ConnectionResetError`
will now throw this -- by :user:`Dreamsorcerer`.
2 changes: 2 additions & 0 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .client import (
BaseConnector,
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorError,
ClientConnectorSSLError,
Expand Down Expand Up @@ -124,6 +125,7 @@
# client
"BaseConnector",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorError",
"ClientConnectorSSLError",
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Optional, cast

from .client_exceptions import ClientConnectionResetError
from .helpers import set_exception
from .tcp_helpers import tcp_nodelay

Expand Down Expand Up @@ -85,7 +86,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:

async def _drain_helper(self) -> None:
if not self.connected:
raise ConnectionResetError("Connection lost")
raise ClientConnectionResetError("Connection lost")
if not self._paused:
return
waiter = self._drain_waiter
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .abc import AbstractCookieJar
from .client_exceptions import (
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorError,
ClientConnectorSSLError,
Expand Down Expand Up @@ -102,6 +103,7 @@
__all__ = (
# client_exceptions
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorError",
"ClientConnectorSSLError",
Expand Down
9 changes: 7 additions & 2 deletions aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from multidict import MultiMapping

from .http_parser import RawResponseMessage
from .typedefs import StrOrURL

try:
Expand All @@ -19,12 +18,14 @@

if TYPE_CHECKING:
from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo
from .http_parser import RawResponseMessage
else:
RequestInfo = ClientResponse = ConnectionKey = None
RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None

__all__ = (
"ClientError",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientOSError",
"ClientConnectorError",
"ClientProxyConnectionError",
Expand Down Expand Up @@ -159,6 +160,10 @@ class ClientConnectionError(ClientError):
"""Base class for client socket errors."""


class ClientConnectionResetError(ClientConnectionError, ConnectionResetError):
"""ConnectionResetError"""


class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""

Expand Down
5 changes: 3 additions & 2 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS, set_exception
from .streams import DataQueue
Expand Down Expand Up @@ -624,7 +625,7 @@ async def _send_frame(
) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ConnectionResetError("Cannot write to closing transport")
raise ClientConnectionResetError("Cannot write to closing transport")

# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
Expand Down Expand Up @@ -719,7 +720,7 @@ def _make_compress_obj(self, compress: int) -> ZLibCompressor:

def _write(self, data: bytes) -> None:
if self.transport is None or self.transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
raise ClientConnectionResetError("Cannot write to closing transport")
self.transport.write(data)

async def pong(self, message: Union[bytes, str] = b"") -> None:
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS

Expand Down Expand Up @@ -72,7 +73,7 @@ def _write(self, chunk: bytes) -> None:
self.output_size += size
transport = self.transport
if not self._protocol.connected or transport is None or transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(chunk)

async def write(
Expand Down
6 changes: 6 additions & 0 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2207,6 +2207,10 @@ Connection errors

Derived from :exc:`ClientError`

.. class:: ClientConnectionResetError

Derived from :exc:`ClientConnectionError` and :exc:`ConnectionResetError`

.. class:: ClientOSError

Subset of connection errors that are initiated by an :exc:`OSError`
Expand Down Expand Up @@ -2293,6 +2297,8 @@ Hierarchy of exceptions

* :exc:`ClientConnectionError`

* :exc:`ClientConnectionResetError`

* :exc:`ClientOSError`

* :exc:`ClientConnectorError`
Expand Down
23 changes: 15 additions & 8 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
import base64
import hashlib
import os
from typing import Any
from typing import Any, Type
from unittest import mock

import pytest

import aiohttp
from aiohttp import client, hdrs
from aiohttp.client_exceptions import ServerDisconnectedError
from aiohttp import (
ClientConnectionResetError,
ServerDisconnectedError,
client,
hdrs,
)
from aiohttp.http import WS_KEY
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro
Expand Down Expand Up @@ -508,10 +512,13 @@ async def test_close_exc2(loop, ws_key, key_data) -> None:
await resp.close()


async def test_send_data_after_close(ws_key, key_data, loop) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
@pytest.mark.parametrize("exc", (ClientConnectionResetError, ConnectionResetError))
async def test_send_data_after_close(
exc: Type[Exception],
ws_key: bytes,
key_data: bytes,
loop: asyncio.AbstractEventLoop,
) -> None:
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
Expand All @@ -533,7 +540,7 @@ async def test_send_data_after_close(ws_key, key_data, loop) -> None:
(resp.send_bytes, (b"b",)),
(resp.send_json, ({},)),
):
with pytest.raises(ConnectionResetError):
with pytest.raises(exc): # Verify exc can be caught with both classes
await meth(*args)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import aiohttp
from aiohttp import ServerTimeoutError, WSMsgType, hdrs, web
from aiohttp import ClientConnectionResetError, ServerTimeoutError, WSMsgType, hdrs, web
from aiohttp.http import WSCloseCode
from aiohttp.pytest_plugin import AiohttpClient

Expand Down Expand Up @@ -620,7 +620,7 @@ async def handler(request: web.Request) -> NoReturn:
# would cancel the heartbeat task and we wouldn't get a ping
assert resp._conn is not None
with mock.patch.object(
resp._conn.transport, "write", side_effect=ConnectionResetError
resp._conn.transport, "write", side_effect=ClientConnectionResetError
), mock.patch.object(resp._writer, "ping", wraps=resp._writer.ping) as ping:
await resp.receive()
ping_count = ping.call_count
Expand Down
10 changes: 6 additions & 4 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from multidict import CIMultiDict

from aiohttp import http
from aiohttp import ClientConnectionResetError, http
from aiohttp.test_utils import make_mocked_coro


Expand Down Expand Up @@ -232,12 +232,12 @@ async def test_write_to_closing_transport(protocol, transport, loop) -> None:
await msg.write(b"Before closing")
transport.is_closing.return_value = True

with pytest.raises(ConnectionResetError):
with pytest.raises(ClientConnectionResetError):
await msg.write(b"After closing")


async def test_write_to_closed_transport(protocol, transport, loop) -> None:
"""Test that writing to a closed transport raises ConnectionResetError.
"""Test that writing to a closed transport raises ClientConnectionResetError.
The StreamWriter checks to see if protocol.transport is None before
writing to the transport. If it is None, it raises ConnectionResetError.
Expand All @@ -247,7 +247,9 @@ async def test_write_to_closed_transport(protocol, transport, loop) -> None:
await msg.write(b"Before transport close")
protocol.transport = None

with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"):
with pytest.raises(
ClientConnectionResetError, match="Cannot write to closing transport"
):
await msg.write(b"After transport closed")


Expand Down

0 comments on commit cf8c986

Please sign in to comment.