diff --git a/docs/Usage/FastAPI-Helper.md b/docs/Usage/FastAPI-Helper.md index 45f1ebe..ccdee99 100644 --- a/docs/Usage/FastAPI-Helper.md +++ b/docs/Usage/FastAPI-Helper.md @@ -29,7 +29,7 @@ app = reverse_http_app(client=client, base_url=base_url) ``` 1. You can pass `httpx.AsyncClient` instance: - - if you want to customize the arguments, e.g. `httpx.AsyncClient(proxies={})` + - if you want to customize the arguments, e.g. `httpx.AsyncClient(http2=True)` - if you want to reuse the connection pool of `httpx.AsyncClient` --- Or you can pass `None`(The default value), then `fastapi-proxy-lib` will create a new `httpx.AsyncClient` instance for you. diff --git a/pyproject.toml b/pyproject.toml index f3c9bbf..2d7318c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,9 +50,11 @@ dynamic = ["version"] dependencies = [ "httpx", - "httpx-ws >= 0.4.2", - "starlette", + "httpx-ws >= 0.6.0", + "starlette >= 0.37.2", "typing_extensions >=4.5.0", + "anyio >= 4", + "exceptiongroup", ] [project.optional-dependencies] @@ -96,10 +98,11 @@ dependencies = [ "pytest == 7.*", "pytest-cov == 4.*", "uvicorn[standard] < 1.0.0", # TODO: Once it releases version 1.0.0, we will remove this restriction. + "hypercorn[trio] == 0.16.*", "httpx[http2]", # we don't set version here, instead set it in `[project].dependencies`. - "anyio", # we don't set version here, because fastapi has a dependency on it - "asgi-lifespan==2.*", - "pytest-timeout==2.*", + "asgi-lifespan == 2.*", + "pytest-timeout == 2.*", + "sniffio == 1.3.*", ] [tool.hatch.envs.default.scripts] diff --git a/src/fastapi_proxy_lib/core/_tool.py b/src/fastapi_proxy_lib/core/_tool.py index 5b3d014..5604dec 100644 --- a/src/fastapi_proxy_lib/core/_tool.py +++ b/src/fastapi_proxy_lib/core/_tool.py @@ -1,13 +1,11 @@ """The utils tools for both http proxy and websocket proxy.""" import ipaddress -import logging import warnings from functools import lru_cache from textwrap import dedent from typing import ( Any, - Iterable, Mapping, Optional, Protocol, @@ -17,7 +15,6 @@ ) import httpx -from starlette import status from starlette.background import BackgroundTask as BackgroundTask_t from starlette.datastructures import ( Headers as StarletteHeaders, @@ -26,13 +23,11 @@ MutableHeaders as StarletteMutableHeaders, ) from starlette.responses import JSONResponse -from starlette.types import Scope from typing_extensions import deprecated, overload __all__ = ( "check_base_url", "return_err_msg_response", - "check_http_version", "BaseURLError", "ErrMsg", "ErrRseponseJson", @@ -129,10 +124,6 @@ class _RejectedProxyRequestError(RuntimeError): """Should be raised when reject proxy request.""" -class _UnsupportedHttpVersionError(RuntimeError): - """Unsupported http version.""" - - #################### Tools #################### @@ -309,8 +300,8 @@ def return_err_msg_response( err_response_json = ErrRseponseJson(detail=detail) # TODO: 请注意,logging是同步函数,每次会阻塞1ms左右,这可能会导致性能问题 - # 特别是对于写入文件的log,最好把它放到 asyncio.to_thread 里执行 - # https://docs.python.org/zh-cn/3/library/asyncio-task.html#coroutine + # 特别是对于写入文件的log,最好把它放到 `anyio.to_thread.run_sync()` 里执行 + # https://anyio.readthedocs.io/en/stable/threads.html#running-a-function-in-a-worker-thread if logger is not None: # 只要传入了logger,就一定记录日志 @@ -337,35 +328,6 @@ def return_err_msg_response( ) -def check_http_version( - scope: Scope, supported_versions: Iterable[str] -) -> Union[JSONResponse, None]: - """Check whether the http version of scope is in supported_versions. - - Args: - scope: asgi scope dict. - supported_versions: The supported http versions. - - Returns: - If the http version of scope is not in supported_versions, return a JSONResponse with status_code=505, - else return None. - """ - # https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope - # https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope - http_version: str = scope.get("http_version", "") - # 如果明确指定了http版本(即不是""),但不在支持的版本内,则返回505 - if http_version not in supported_versions and http_version != "": - error = _UnsupportedHttpVersionError( - f"The request http version is {http_version}, but we only support {supported_versions}." - ) - # TODO: 或许可以logging记录下 scope.get("client") 的值 - return return_err_msg_response( - error, - status_code=status.HTTP_505_HTTP_VERSION_NOT_SUPPORTED, - logger=logging.info, - ) - - def default_proxy_filter(url: httpx.URL) -> Union[None, str]: """Filter by host. diff --git a/src/fastapi_proxy_lib/core/http.py b/src/fastapi_proxy_lib/core/http.py index fab3316..c3a7613 100644 --- a/src/fastapi_proxy_lib/core/http.py +++ b/src/fastapi_proxy_lib/core/http.py @@ -31,7 +31,6 @@ _RejectedProxyRequestError, # pyright: ignore [reportPrivateUsage] # 允许使用本项目内部的私有成员 change_necessary_client_header_for_httpx, check_base_url, - check_http_version, return_err_msg_response, warn_for_none_filter, ) @@ -81,10 +80,6 @@ class _ReverseProxyServerError(RuntimeError): _NON_REQUEST_BODY_METHODS = ("GET", "HEAD", "OPTIONS", "TRACE") """The http methods that should not contain request body.""" -# https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope -SUPPORTED_HTTP_VERSIONS = ("1.0", "1.1") -"""The http versions that we supported now. It depends on `httpx`.""" - # https://www.python-httpx.org/exceptions/ _400_ERROR_NEED_TO_BE_CATCHED_IN_FORWARD_PROXY = ( httpx.InvalidURL, # 解析url时出错 @@ -227,8 +222,6 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv ) -> StarletteResponse: """Change request headers and send request to target url. - - The http version of request must be in [`SUPPORTED_HTTP_VERSIONS`][fastapi_proxy_lib.core.http.SUPPORTED_HTTP_VERSIONS]. - Args: request: the original client request. target_url: target url that request will be sent to. @@ -239,10 +232,6 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv client = self.client follow_redirects = self.follow_redirects - check_result = check_http_version(request.scope, SUPPORTED_HTTP_VERSIONS) - if check_result is not None: - return check_result - # 将请求头中的host字段改为目标url的host # 同时强制移除"keep-alive"字段和添加"keep-alive"值到"connection"字段中保持连接 require_close, proxy_header = _change_client_header( @@ -338,8 +327,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: # (1)! app = FastAPI(lifespan=close_proxy_event) @app.get("/{path:path}") # (2)! - async def _(request: Request, path: str = ""): - return await proxy.proxy(request=request, path=path) # (3)! + async def _(request: Request): + return await proxy.proxy(request=request) # (3)! # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `http://127.0.0.1:8000/` @@ -350,10 +339,6 @@ async def _(request: Request, path: str = ""): 2. `{path:path}` is the key.
It allows the app to accept all path parameters.
visit for more info. - 3. !!! info - In fact, you only need to pass the `request: Request` argument.
- `fastapi_proxy_lib` can automatically get the `path` from `request`.
- Explicitly pointing it out here is just to remind you not to forget to specify `{path:path}`. ''' client: httpx.AsyncClient @@ -387,15 +372,12 @@ def __init__( @override async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] - self, *, request: StarletteRequest, path: Optional[str] = None + self, *, request: StarletteRequest ) -> StarletteResponse: """Send request to target server. Args: request: `starlette.requests.Request` - path: The path params of request, which means the path params of base url.
- If None, will get it from `request.path_params`.
- **Usually, you don't need to pass this argument**. Returns: The response from target server. @@ -403,9 +385,7 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] base_url = self.base_url # 只取第一个路径参数。注意,我们允许没有路径参数,这代表直接请求 - path_param: str = ( - path if path is not None else next(iter(request.path_params.values()), "") - ) + path_param: str = next(iter(request.path_params.values()), "") # 将路径参数拼接到目标url上 # e.g: "https://www.example.com/p0/" + "p1" @@ -473,8 +453,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: app = FastAPI(lifespan=close_proxy_event) @app.get("/{path:path}") - async def _(request: Request, path: str = ""): - return await proxy.proxy(request=request, path=path) + async def _(request: Request): + return await proxy.proxy(request=request) # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `http://127.0.0.1:8000/http://www.example.com` @@ -513,15 +493,11 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] self, *, request: StarletteRequest, - path: Optional[str] = None, ) -> StarletteResponse: """Send request to target server. Args: request: `starlette.requests.Request` - path: The path params of request, which means the full url of target server.
- If None, will get it from `request.path_params`.
- **Usually, you don't need to pass this argument**. Returns: The response from target server. @@ -529,9 +505,8 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] proxy_filter = self.proxy_filter # 只取第一个路径参数 - path_param: str = ( - next(iter(request.path_params.values()), "") if path is None else path - ) + path_param: str = next(iter(request.path_params.values()), "") + # 如果没有路径参数,即在正向代理中未指定目标url,则返回400 if path_param == "": error = _BadTargetUrlError("Must provide target url.") diff --git a/src/fastapi_proxy_lib/core/websocket.py b/src/fastapi_proxy_lib/core/websocket.py index 36dacd7..223c381 100644 --- a/src/fastapi_proxy_lib/core/websocket.py +++ b/src/fastapi_proxy_lib/core/websocket.py @@ -1,34 +1,30 @@ """The websocket proxy lib.""" -import asyncio import logging +import warnings +from collections import deque from contextlib import AsyncExitStack +from textwrap import dedent from typing import ( TYPE_CHECKING, Any, List, - Literal, - NamedTuple, NoReturn, Optional, Union, ) +import anyio +import anyio.abc import httpx import httpx_ws import starlette.websockets as starlette_ws -from httpx_ws._api import ( # HACK: 注意,这个是私有模块 - DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - DEFAULT_MAX_MESSAGE_SIZE_BYTES, - DEFAULT_QUEUE_SIZE, -) +import wsproto from starlette import status as starlette_status -from starlette.exceptions import WebSocketException as StarletteWebSocketException -from starlette.responses import Response as StarletteResponse -from starlette.responses import StreamingResponse +from starlette.background import BackgroundTask +from starlette.responses import Response from starlette.types import Scope -from typing_extensions import TypeAlias, override +from typing_extensions import override from wsproto.events import BytesMessage as WsprotoBytesMessage from wsproto.events import TextMessage as WsprotoTextMessage @@ -36,9 +32,41 @@ from ._tool import ( change_necessary_client_header_for_httpx, check_base_url, - check_http_version, ) +# XXX: because these variables are private, we have to use try-except to avoid errors +try: + from httpx_ws._api import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, + ) +except ImportError: # pragma: no cover + # ref: https://github.com/frankie567/httpx-ws/blob/b2135792141b71551b022ff0d76542a0263a890c/httpx_ws/_api.py#L31-L34 + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = ( # pyright: ignore[reportConstantRedefinition] + 20.0 + ) + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = ( # pyright: ignore[reportConstantRedefinition] + 20.0 + ) + DEFAULT_MAX_MESSAGE_SIZE_BYTES = ( # pyright: ignore[reportConstantRedefinition] + 65_536 + ) + DEFAULT_QUEUE_SIZE = 512 # pyright: ignore[reportConstantRedefinition] + + warnings.warn( + dedent( + """\ + Can not import the default httpx_ws arguments, please open an issue on: + https://github.com/WSH032/fastapi-proxy-lib\ + """ + ), + RuntimeWarning, + stacklevel=1, + ) + + __all__ = ( "BaseWebSocketProxy", "ReverseWebSocketProxy", @@ -52,25 +80,18 @@ #################### Data Model #################### -_ClentToServerTaskType: TypeAlias = "asyncio.Task[starlette_ws.WebSocketDisconnect]" -_ServerToClientTaskType: TypeAlias = "asyncio.Task[httpx_ws.WebSocketDisconnect]" - - -class _ClientServerProxyTask(NamedTuple): - """The task group for passing websocket message between client and target server.""" - - client_to_server_task: _ClentToServerTaskType - server_to_client_task: _ServerToClientTaskType +_WsDisconnectType = Union[ + starlette_ws.WebSocketDisconnect, httpx_ws.WebSocketDisconnect +] +_WsDisconnectDequqType = Union[ + _WsDisconnectType, # Exception that contains closing info + Exception, # other Exception +] #################### Constant #################### -# https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope -SUPPORTED_WS_HTTP_VERSIONS = ("1.1",) -"""The http versions that we supported now. It depends on `httpx`.""" - - #################### Error #################### @@ -106,36 +127,36 @@ async def _starlette_ws_receive_bytes_or_str( """Receive bytes or str from starlette WebSocket. - There is already a queue inside to store the received data - - Even if Exception is raised, the {WebSocket} would **not** be closed automatically, you should close it manually + - Even if `AssertionError` is raised, the `WebSocket` would **not** be closed automatically, + you should close it manually, Args: websocket: The starlette WebSocket that has been connected. - "has been connected" measn that you have called "websocket.accept" first. + "has been connected" means that you have called "websocket.accept" first. Raises: starlette.websockets.WebSocketDisconnect: If the WebSocket is disconnected. WebSocketDisconnect.code is the close code. WebSocketDisconnect.reason is the close reason. - **This is normal behavior that you should catch** - StarletteWebSocketException: If receive a invalid message type which is neither bytes nor str. - StarletteWebSocketException.code = starlette_status.WS_1008_POLICY_VIOLATION - StarletteWebSocketException.reason is the close reason. - - RuntimeError: If the WebSocket is not connected. Need to call "accept" first. - If the {websocket} argument you passed in is correct, this error will never be raised, just for asset. + AssertionError: + - If receive a invalid message type which is neither bytes nor str. + - RuntimeError: If the WebSocket is not connected. Need to call "accept" first. + If the `websocket` argument passed in is correct, this error will never be raised, just for assertion. Returns: bytes | str: The received data. """ - # 实现参考: + # Implement reference: # https://github.com/encode/starlette/blob/657e7e7b728e13dc66cc3f77dffd00a42545e171/starlette/websockets.py#L107C1-L115C1 assert ( websocket.application_state == starlette_ws.WebSocketState.CONNECTED ), """WebSocket is not connected. Need to call "accept" first.""" message = await websocket.receive() - # maybe raise WebSocketDisconnect - websocket._raise_on_disconnect(message) # pyright: ignore [reportPrivateUsage] + + if message["type"] == "websocket.disconnect": + raise starlette_ws.WebSocketDisconnect(message["code"], message.get("reason")) # https://asgi.readthedocs.io/en/latest/specs/www.html#receive-receive-event if message.get("bytes") is not None: @@ -143,12 +164,8 @@ async def _starlette_ws_receive_bytes_or_str( elif message.get("text") is not None: return message["text"] else: - # 这种情况应该不会发生,因为这是ASGI标准 + # It should never happen, because of the ASGI spec raise AssertionError("message should have 'bytes' or 'text' key") - raise StarletteWebSocketException( - code=starlette_status.WS_1008_POLICY_VIOLATION, - reason="Invalid message type received (neither bytes nor text).", - ) # 为什么使用这个函数而不是直接使用httpx_ws_AsyncWebSocketSession.receive_text() @@ -159,8 +176,8 @@ async def _httpx_ws_receive_bytes_or_str( """Receive bytes or str from httpx_ws AsyncWebSocketSession . - There is already a queue inside to store the received data - - Even if Exception is raised, the {WebSocket} would **not** be closed automatically, you should close it manually - - except for httpx_ws.WebSocketNetworkError, which will call 'close' automatically + - Even if `AssertionError` or `httpx_ws.WebSocketNetworkError` is raised, the `WebSocket` would **not** be closed automatically, + you should close it manually, Args: websocket: The httpx_ws AsyncWebSocketSession that has been connected. @@ -171,9 +188,8 @@ async def _httpx_ws_receive_bytes_or_str( WebSocketDisconnect.reason is the close reason. - **This is normal behavior that you should catch** httpx_ws.WebSocketNetworkError: A network error occurred. - - httpx_ws.WebSocketInvalidTypeReceived: If receive a invalid message type which is neither bytes nor str. - Usually it will never be raised, just for assert + AssertionError: If receive a invalid message type which is neither bytes nor str. + Usually it will never be raised, just for assertion Returns: bytes | str: The received data. @@ -198,7 +214,7 @@ async def _httpx_ws_receive_bytes_or_str( else: # pragma: no cover # 无法测试这个分支,因为无法发送这种消息,正常来说也不会被执行,所以我们这里记录critical msg = f"Invalid message type received: {type(event)}" logging.critical(msg) - raise httpx_ws.WebSocketInvalidTypeReceived(event) + raise AssertionError(event) async def _httpx_ws_send_bytes_or_str( @@ -207,10 +223,10 @@ async def _httpx_ws_send_bytes_or_str( ) -> None: """Send bytes or str to WebSocket. - - Usually, when Exception is raised, the {WebSocket} is already closed. + - Usually, when Exception is raised, the `WebSocket` is already closed. Args: - websocket: The httpx_ws.AsyncWebSocketSession that has been connected. + websocket: The `httpx_ws.AsyncWebSocketSession` that has been connected. data: The data to send. Raises: @@ -236,18 +252,14 @@ async def _starlette_ws_send_bytes_or_str( ) -> None: """Send bytes or str to WebSocket. - - Even if Exception is raised, the {WebSocket} would **not** be closed automatically, you should close it manually + - Even if Exception is raised, the `WebSocket` would **not** be closed automatically, you should close it manually Args: websocket: The starlette_ws.WebSocket that has been connected. data: The data to send. Raises: - When websocket has been disconnected, there may be exceptions raised, or maybe not. - # https://github.com/encode/uvicorn/discussions/2137 - For Uvicorn backend: - - `wsproto`: nothing raised. - - `websockets`: websockets.exceptions.ConnectionClosedError + When websocket has been disconnected, will raise `starlette_ws.WebSocketDisconnect(1006)`. """ # HACK: make pyright happy @@ -261,136 +273,191 @@ async def _starlette_ws_send_bytes_or_str( async def _wait_client_then_send_to_server( - *, client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession -) -> starlette_ws.WebSocketDisconnect: + client_ws: starlette_ws.WebSocket, + server_ws: httpx_ws.AsyncWebSocketSession, + ws_disconnect_deque: "deque[_WsDisconnectDequqType]", + task_group: anyio.abc.TaskGroup, +) -> None: """Receive data from client, then send to target server. Args: client_ws: The websocket which receive data of client. server_ws: The websocket which send data to target server. + ws_disconnect_deque: A deque to store Exception. + task_group: The task group which run this task. + if catch a Exception, will cancel the task group. Returns: - If the client_ws sends a shutdown message normally, will return starlette_ws.WebSocketDisconnect. + None: Always run forever, except encounter Exception. Raises: - error for receiving: refer to `_starlette_ws_receive_bytes_or_str` - error for sending: refer to `_httpx_ws_send_bytes_or_str` + error for receiving: refer to `_starlette_ws_receive_bytes_or_str`. + error for sending: refer to `_httpx_ws_send_bytes_or_str`. """ - while True: - try: + try: + while True: receive = await _starlette_ws_receive_bytes_or_str(client_ws) - except starlette_ws.WebSocketDisconnect as e: - return e - else: await _httpx_ws_send_bytes_or_str(server_ws, receive) + except Exception as ws_disconnect: + task_group.cancel_scope.cancel() + ws_disconnect_deque.append(ws_disconnect) async def _wait_server_then_send_to_client( - *, client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession -) -> httpx_ws.WebSocketDisconnect: + client_ws: starlette_ws.WebSocket, + server_ws: httpx_ws.AsyncWebSocketSession, + ws_disconnect_deque: "deque[_WsDisconnectDequqType]", + task_group: anyio.abc.TaskGroup, +) -> None: """Receive data from target server, then send to client. Args: - client_ws: The websocket which send data to client. - server_ws: The websocket which receive data of target server. + client_ws: The websocket which receive data of client. + server_ws: The websocket which send data to target server. + ws_disconnect_deque: A deque to store Exception. + task_group: The task group which run this task. + if catch a Exception, will cancel the task group. Returns: - If the server_ws sends a shutdown message normally, will return httpx_ws.WebSocketDisconnect. + None: Always run forever, except encounter Exception. Raises: - error for receiving: refer to `_httpx_ws_receive_bytes_or_str` - error for sending: refer to `_starlette_ws_send_bytes_or_str` + error for receiving: refer to `_httpx_ws_receive_bytes_or_str`. + error for sending: refer to `_starlette_ws_send_bytes_or_str`. """ - while True: - try: + try: + while True: receive = await _httpx_ws_receive_bytes_or_str(server_ws) - except httpx_ws.WebSocketDisconnect as e: - return e - else: await _starlette_ws_send_bytes_or_str(client_ws, receive) + except Exception as ws_disconnect: + task_group.cancel_scope.cancel() + ws_disconnect_deque.append(ws_disconnect) -async def _close_ws( - *, - client_to_server_task: _ClentToServerTaskType, - server_to_client_task: _ServerToClientTaskType, +async def _close_ws( # noqa: C901, PLR0912 client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession, -) -> None: - """Close ws connection and send status code based on task results. - - - If there is an error, or can't get status code from tasks, then always send a 1011 status code - - Will close ws connection whatever happens. - - Args: - client_to_server_task: client_to_server_task - server_to_client_task: server_to_client_task - client_ws: client_ws - server_ws: server_ws - """ - try: - # NOTE: 先判断 cancelled ,因为被取消的 task.exception() 会引发异常 - client_error = ( - asyncio.CancelledError - if client_to_server_task.cancelled() - else client_to_server_task.exception() - ) - server_error = ( - asyncio.CancelledError - if server_to_client_task.cancelled() - else server_to_client_task.exception() - ) + ws_disconnect_deque: "deque[_WsDisconnectDequqType]", + caught_tg_exc: Optional[BaseException], +): + ws_disconnect_tuple = tuple(ws_disconnect_deque) + + client_disc_errs: List[starlette_ws.WebSocketDisconnect] = [] + not_client_disc_errs: List[Exception] = [] + for e in ws_disconnect_tuple: + if isinstance(e, starlette_ws.WebSocketDisconnect): + client_disc_errs.append(e) + else: + not_client_disc_errs.append(e) - if client_error is None: - # clinet端收到正常关闭消息,则关闭server端 - disconnection = client_to_server_task.result() - await server_ws.close(disconnection.code, disconnection.reason) - return - elif server_error is None: - # server端收到正常关闭消息,则关闭client端 - disconnection = server_to_client_task.result() - await client_ws.close(disconnection.code, disconnection.reason) - return + server_disc_errs: List[httpx_ws.WebSocketDisconnect] = [] + not_server_disc_errs: List[Exception] = [] + for e in ws_disconnect_tuple: + if isinstance(e, httpx_ws.WebSocketDisconnect): + server_disc_errs.append(e) else: - # 如果上述情况都没有发生,意味着至少其中一个任务发生了异常,导致了另一个任务被取消 - # NOTE: 我们不在这个分支调用 `ws.close`,而是留到最后的 finally 来关闭 - client_info = client_ws.client - client_host, client_port = ( - (client_info.host, client_info.port) - if client_info is not None - else (None, None) + not_server_disc_errs.append(e) + + is_canceled = isinstance(caught_tg_exc, anyio.get_cancelled_exc_class()) + + client_host, client_port = ( + (client_ws.client.host, client_ws.client.port) + if client_ws.client is not None + else (None, None) + ) + + # Implement reference: + # https://github.com/encode/starlette/blob/4e453ce91940cc7c995e6c728e3fdf341c039056/starlette/websockets.py#L64-L112 + client_ws_closed_state = { + starlette_ws.WebSocketState.DISCONNECTED, + starlette_ws.WebSocketState.RESPONSE, + } + if ( + client_ws.application_state not in client_ws_closed_state + and client_ws.client_state not in client_ws_closed_state + ): + if server_disc_errs: + server_disc = server_disc_errs[0] + await client_ws.close(server_disc.code, server_disc.reason) + elif is_canceled: + await client_ws.close(starlette_status.WS_1001_GOING_AWAY) + else: + await client_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) + logging.warning( + f"[{client_host}:{client_port}] Client websocket closed abnormally during proxying. " + f"Catch tasks exceptions: {not_server_disc_errs!r} " + f"Catch task group exceptions: {caught_tg_exc!r}" ) - # 这里不用dedent是为了更好的性能 - msg = f"""\ -An error occurred in the websocket connection for {client_host}:{client_port}. -client_error: {client_error} -server_error: {server_error}\ -""" - logging.warning(msg) - - except ( - Exception - ) as e: # pragma: no cover # 这个分支是一个保险分支,通常无法执行,所以只进行记录 - logging.error( - f"{e} when close ws connection. client: {client_to_server_task}, server:{server_to_client_task}" - ) - raise - finally: - # 无论如何,确保关闭两个websocket - # 状态码参考: https://developer.mozilla.org/zh-CN/docs/Web/API/CloseEvent - # https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 - try: - await client_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) - except Exception: - # 这个分支通常会被触发,因为uvicorn服务器在重复调用close时会引发异常 - pass - try: + # Implement reference: + # https://github.com/frankie567/httpx-ws/blob/940c9adb3afee9dd7c8b95514bdf6444673e4820/httpx_ws/_api.py#L928-L931 + if server_ws.connection.state not in { + wsproto.connection.ConnectionState.CLOSED, + wsproto.connection.ConnectionState.LOCAL_CLOSING, + }: + if client_disc_errs: + client_disc = client_disc_errs[0] + await server_ws.close(client_disc.code, client_disc.reason) + elif is_canceled: + await server_ws.close(starlette_status.WS_1001_GOING_AWAY) + else: + # If remote server has closed normally, here we just close local ws. + # It's normal, so we don't need warning. + if ( + server_ws.connection.state + != wsproto.connection.ConnectionState.REMOTE_CLOSING + ): + logging.warning( + f"[{client_host}:{client_port}] Server websocket closed abnormally during proxying. " + f"Catch tasks exceptions: {not_client_disc_errs!r} " + f"Catch task group exceptions: {caught_tg_exc!r}" + ) await server_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) - except Exception as e: # pragma: no cover - # 这个分支是一个保险分支,通常无法执行,所以只进行记录 - # 不会触发的原因是,负责服务端 ws 连接的 httpx_ws 支持重复调用close而不引发错误 - logging.debug("Unexpected error for debug", exc_info=e) + + +async def _handle_ws_upgrade_error( + client_ws: starlette_ws.WebSocket, + background: BackgroundTask, + ws_upgrade_exc: httpx_ws.WebSocketUpgradeError, +) -> None: + proxy_res = ws_upgrade_exc.response + # https://asgi.readthedocs.io/en/latest/extensions.html#websocket-denial-response + # https://github.com/encode/starlette/blob/4e453ce91940cc7c995e6c728e3fdf341c039056/starlette/websockets.py#L207-L214 + is_able_to_send_denial_response = "websocket.http.response" in client_ws.scope.get( + "extensions", {} + ) + + if is_able_to_send_denial_response: + # # XXX: Can not use send_denial_response with StreamingResponse + # # See: https://github.com/encode/starlette/discussions/2566 + # denial_response = StreamingResponse( + # content=proxy_res.aiter_raw(), + # status_code=proxy_res.status_code, + # headers=proxy_res.headers, + # background=background, + # ) + + # # XXX: Unable to read the content of WebSocketUpgradeError.response + # # See: https://github.com/frankie567/httpx-ws/discussions/69 + # content = await proxy_res.aread() + + denial_response = Response( + content="", + status_code=proxy_res.status_code, + headers=proxy_res.headers, + background=background, + ) + await client_ws.send_denial_response(denial_response) + else: + msg = ( + "Proxy websocket handshake failed, " + "but your ASGI server does not support sending denial response.\n" + f"Denial response: {proxy_res!r}" + ) + logging.warning(msg) + # we close before accept, then ASGI will send 403 to client + await client_ws.close() + await background() #################### # #################### @@ -461,21 +528,15 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv *, websocket: starlette_ws.WebSocket, target_url: httpx.URL, - ) -> Union[Literal[False], StarletteResponse]: + ) -> bool: """Establish websocket connection for both client and target_url, then pass messages between them. - - The http version of request must be in [`SUPPORTED_WS_HTTP_VERSIONS`][fastapi_proxy_lib.core.websocket.SUPPORTED_WS_HTTP_VERSIONS]. - Args: websocket: The client websocket requests. target_url: The url of target websocket server. Returns: - If the establish websocket connection unsuccessfully: - - Will call `websocket.close()` to send code `4xx` - - Then return a `StarletteResponse` from target server - If the establish websocket connection successfully: - - Will run forever until the connection is closed. Then return False. + bool: If handshake failed, return True. Else return False. """ client = self.client follow_redirects = self.follow_redirects @@ -495,13 +556,6 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv ) client_request_params: "QueryParamTypes" = websocket.query_params - # TODO: 是否可以不检查http版本? - check_result = check_http_version(websocket.scope, SUPPORTED_WS_HTTP_VERSIONS) - if check_result is not None: - # NOTE: return 之前最好关闭websocket - await websocket.close() - return check_result - # DEBUG: 用于调试的记录 logging.debug( "WS: client:%s ; url:%s ; params:%s ; headers:%s", @@ -523,7 +577,7 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv proxy_ws = await stack.enter_async_context( httpx_ws.aconnect_ws( - # 这个是httpx_ws类型注解的问题,其实是可以使用httpx.URL的 + # XXX: 这个是httpx_ws类型注解的问题,其实是可以使用httpx.URL的 url=target_url, # pyright: ignore [reportArgumentType] client=client, max_message_size_bytes=max_message_size_bytes, @@ -538,109 +592,68 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv follow_redirects=follow_redirects, ) ) - except httpx_ws.WebSocketUpgradeError as e: - # 这个错误是在 httpx.stream 获取到响应后才返回的, 也就是说至少本服务器的网络应该是正常的 - # 且对于反向ws代理来说,本服务器管理者有义务保证与目标服务器的连接是正常的 - # 所以这里既有可能是客户端的错误,或者是目标服务器拒绝了连接 - # TODO: 也有可能是本服务器的未知错误 - proxy_res = e.response - - # NOTE: return 之前最好关闭websocket - # 不调用websocket.accept就发送关闭请求,uvicorn会自动发送403错误 - await websocket.close() - # TODO: 连接失败的时候httpx_ws会自己关闭连接,但或许这里显式关闭会更好 - - # HACK: 这里的返回的响应其实uvicorn不会处理 - return StreamingResponse( - content=proxy_res.aiter_raw(), - status_code=proxy_res.status_code, - headers=proxy_res.headers, + except httpx_ws.WebSocketUpgradeError as ws_upgrade_exc: + await _handle_ws_upgrade_error( + client_ws=websocket, + background=BackgroundTask(stack.aclose), + ws_upgrade_exc=ws_upgrade_exc, ) + return True # NOTE: 对于反向代理服务器,我们不返回 "任何" "具体的内部" 错误信息给客户端,因为这可能涉及到服务器内部的信息泄露 # NOTE: 请使用 with 语句来 "保证关闭" AsyncWebSocketSession async with stack: - # TODO: websocket.accept 中还有一个headers参数,但是httpx_ws不支持,考虑发起PR - # https://github.com/frankie567/httpx-ws/discussions/53 - - # FIXME: 调查缺少headers参数是否会引起问题,及是否会影响透明代理的无损转发性 + proxy_ws_resp = proxy_ws.response + # TODO: Here is a typing issue of `httpx_ws`, we have to use `assert` to make pyright happy + # https://github.com/frankie567/httpx-ws/pull/54#pullrequestreview-1974062119 + assert proxy_ws_resp is not None + headers = proxy_ws_resp.headers.copy() + # ASGI not allow the headers contains `sec-websocket-protocol` field # https://asgi.readthedocs.io/en/latest/specs/www.html#accept-send-event - - # 这时候如果发生错误,退出时 stack 会自动关闭 httpx_ws 连接,所以这里不需要手动关闭 + headers.pop("sec-websocket-protocol", None) + # XXX: uvicorn websockets implementation not allow contains multiple `Date` and `Server` field, + # only wsporoto can do so. + # https://github.com/encode/uvicorn/pull/1606 + # https://github.com/python-websockets/websockets/issues/1226 + headers.pop("Date", None) + headers.pop("Server", None) await websocket.accept( - subprotocol=proxy_ws.subprotocol - # headers=... - ) - - client_to_server_task = asyncio.create_task( - _wait_client_then_send_to_server( - client_ws=websocket, - server_ws=proxy_ws, - ), - name="client_to_server_task", - ) - server_to_client_task = asyncio.create_task( - _wait_server_then_send_to_client( - client_ws=websocket, - server_ws=proxy_ws, - ), - name="server_to_client_task", - ) - # 保持强引用: https://docs.python.org/zh-cn/3.12/library/asyncio-task.html#creating-tasks - task_group = _ClientServerProxyTask( - client_to_server_task=client_to_server_task, - server_to_client_task=server_to_client_task, + subprotocol=proxy_ws.subprotocol, headers=headers.raw ) - # NOTE: 考虑这两种情况: - # 1. 如果一个任务在发送阶段退出: - # 这意味着对应发送的ws已经关闭或者出错 - # 那么另一个任务很快就会在接收该ws的时候引发异常而退出 - # 很快,最终两个任务都结束 - # **这时候pending 可能 为空,而done为两个任务** - # 2. 如果一个任务在接收阶段退出: - # 这意味着对应接收的ws已经关闭或者发生出错 - # - 对于另一个任务的发送,可能会在发送的时候引发异常而退出 - # - 可能指的是: wsproto后端的uvicorn发送消息永远不会出错 - # - https://github.com/encode/uvicorn/discussions/2137 - # - 对于另一个任务的接收,可能会等待很久,才能继续进行发送任务而引发异常而退出 - # **这时候pending一般为一个未结束任务** - # - # 因为第二种情况的存在,所以需要用 wait_for 强制让其退出 - # 但考虑到第一种情况,先等它 1s ,看看能否正常退出 + ws_disconnect_deque: "deque[_WsDisconnectDequqType]" = deque() + caught_tg_exc = None try: - _, pending = await asyncio.wait( - task_group, - return_when=asyncio.FIRST_COMPLETED, - ) - for ( - pending_task - ) in pending: # NOTE: pending 一般为一个未结束任务,或者为空 - # 开始取消未结束的任务 - try: - await asyncio.wait_for(pending_task, timeout=1) - except asyncio.TimeoutError: - logging.debug(f"{pending} TimeoutError, it's normal.") - except Exception as e: - # 取消期间可能另一个ws会发生异常,这个是正常情况,且会被 asyncio.wait_for 传播 - logging.debug( - f"{pending} raise error when being canceled, it's normal. error: {e}" - ) - except Exception as e: # pragma: no cover # 这个是保险分支,通常无法执行 - logging.warning( - f"Something wrong, please contact the developer. error: {e}" - ) - raise + async with anyio.create_task_group() as tg: + tg.start_soon( + _wait_client_then_send_to_server, + websocket, + proxy_ws, + ws_disconnect_deque, + tg, + name="client_to_server_task", + ) + tg.start_soon( + _wait_server_then_send_to_client, + websocket, + proxy_ws, + ws_disconnect_deque, + tg, + name="server_to_client_task", + ) + except BaseException as base_exc: + caught_tg_exc = base_exc + raise # NOTE: must raise again finally: - # 无论如何都要关闭两个websocket - # NOTE: 这时候两个任务都已经结束 - await _close_ws( - client_to_server_task=client_to_server_task, - server_to_client_task=server_to_client_task, - client_ws=websocket, - server_ws=proxy_ws, - ) + with anyio.CancelScope(shield=True): + await _close_ws( + websocket, + proxy_ws, + ws_disconnect_deque, + caught_tg_exc, + ) + return False @override @@ -701,8 +714,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: app = FastAPI(lifespan=close_proxy_event) @app.websocket("/{path:path}") - async def _(websocket: WebSocket, path: str = ""): - return await proxy.proxy(websocket=websocket, path=path) + async def _(websocket: WebSocket): + return await proxy.proxy(websocket=websocket) # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `ws://127.0.0.1:8000/` @@ -767,29 +780,20 @@ def __init__( @override async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] - self, *, websocket: starlette_ws.WebSocket, path: Optional[str] = None - ) -> Union[Literal[False], StarletteResponse]: + self, *, websocket: starlette_ws.WebSocket + ) -> bool: """Establish websocket connection for both client and target_url, then pass messages between them. Args: websocket: The client websocket requests. - path: The path params of websocket request, which means the path params of base url.
- If None, will get it from `websocket.path_params`.
- **Usually, you don't need to pass this argument**. Returns: - If the establish websocket connection unsuccessfully: - - Will call `websocket.close()` to send code `4xx` - - Then return a `StarletteResponse` from target server - If the establish websocket connection successfully: - - Will run forever until the connection is closed. Then return False. + bool: If handshake failed, return True. Else return False. """ base_url = self.base_url # 只取第一个路径参数。注意,我们允许没有路径参数,这代表直接请求 - path_param: str = ( - path if path is not None else next(iter(websocket.path_params.values()), "") - ) + path_param: str = next(iter(websocket.path_params.values()), "") # 将路径参数拼接到目标url上 # e.g: "https://www.example.com/p0/" + "p1" diff --git a/src/fastapi_proxy_lib/fastapi/router.py b/src/fastapi_proxy_lib/fastapi/router.py index 076f69d..346c037 100644 --- a/src/fastapi_proxy_lib/fastapi/router.py +++ b/src/fastapi_proxy_lib/fastapi/router.py @@ -3,7 +3,6 @@ The low-level API for [fastapi_proxy_lib.fastapi.app][]. """ -import asyncio import warnings from contextlib import asynccontextmanager from typing import ( @@ -11,13 +10,13 @@ AsyncContextManager, AsyncIterator, Callable, - Literal, Optional, Set, TypeVar, Union, ) +import anyio from fastapi import APIRouter from starlette.requests import Request from starlette.responses import Response @@ -63,7 +62,7 @@ def _http_register_router( @router.patch("/{path:path}", **kwargs) @router.trace("/{path:path}", **kwargs) async def http_proxy( # pyright: ignore[reportUnusedFunction] - request: Request, path: str = "" + request: Request, ) -> Response: """HTTP proxy endpoint. @@ -74,7 +73,7 @@ async def http_proxy( # pyright: ignore[reportUnusedFunction] Returns: The response from target server. """ - return await proxy.proxy(request=request, path=path) + return await proxy.proxy(request=request) def _ws_register_router( @@ -96,8 +95,8 @@ def _ws_register_router( @router.websocket("/{path:path}", **kwargs) async def ws_proxy( # pyright: ignore[reportUnusedFunction] - websocket: WebSocket, path: str = "" - ) -> Union[Response, Literal[False]]: + websocket: WebSocket, + ) -> bool: """WebSocket proxy endpoint. Args: @@ -105,13 +104,9 @@ async def ws_proxy( # pyright: ignore[reportUnusedFunction] path: The path parameters of request. Returns: - If the establish websocket connection unsuccessfully: - - Will call `websocket.close()` to send code `4xx` - - Then return a `StarletteResponse` from target server - If the establish websocket connection successfully: - - Will run forever until the connection is closed. Then return False. + bool: If handshake failed, return True. Else return False. """ - return await proxy.proxy(websocket=websocket, path=path) + return await proxy.proxy(websocket=websocket) class RouterHelper: @@ -273,6 +268,8 @@ async def shutdown_clients(*_: Any, **__: Any) -> AsyncIterator[None]: When __aexit__ is called, will close all registered proxy. """ yield - await asyncio.gather(*[proxy.aclose() for proxy in self.registered_proxy]) + async with anyio.create_task_group() as tg: + for proxy in self.registered_proxy: + tg.start_soon(proxy.aclose) return shutdown_clients diff --git a/tests/app/echo_ws_app.py b/tests/app/echo_ws_app.py index e4925e6..3fb4b6a 100644 --- a/tests/app/echo_ws_app.py +++ b/tests/app/echo_ws_app.py @@ -1,9 +1,10 @@ # ruff: noqa: D100 # pyright: reportUnusedFunction=false -import asyncio +import anyio from fastapi import FastAPI, WebSocket +from starlette.responses import JSONResponse from starlette.websockets import WebSocketDisconnect from .tool import AppDataclass4Test, RequestDict @@ -53,8 +54,8 @@ async def echo_bytes(websocket: WebSocket): except WebSocketDisconnect: break - @app.websocket("/accept_foo_subprotocol") - async def accept_foo_subprotocol(websocket: WebSocket): + @app.websocket("/accept_foo_subprotocol_and_foo_bar_header") + async def accept_foo_subprotocol_and_foo_bar_header(websocket: WebSocket): """When client send subprotocols request, if subprotocols contain "foo", will accept it.""" nonlocal test_app_dataclass test_app_dataclass.request_dict["request"] = websocket @@ -65,19 +66,20 @@ async def accept_foo_subprotocol(websocket: WebSocket): else: accepted_subprotocol = None - await websocket.accept(subprotocol=accepted_subprotocol) + await websocket.accept( + subprotocol=accepted_subprotocol, headers=[(b"foo", b"bar")] + ) await websocket.close() - @app.websocket("/just_close_with_1001") - async def just_close_with_1001(websocket: WebSocket): - """Just do nothing after `accept`, then close ws with 1001 code.""" + @app.websocket("/just_close_with_1002_and_foo") + async def just_close_with_1002_and_foo(websocket: WebSocket): + """Just do nothing after `accept`, then close ws with 1001 code and 'foo'.""" nonlocal test_app_dataclass test_app_dataclass.request_dict["request"] = websocket await websocket.accept() - await asyncio.sleep(0.3) - await websocket.close(1001) + await websocket.close(1002, "foo") @app.websocket("/reject_handshake") async def reject_handshake(websocket: WebSocket): @@ -87,14 +89,34 @@ async def reject_handshake(websocket: WebSocket): await websocket.close() - @app.websocket("/do_nothing") + @app.websocket("/send_denial_response_400_foo_bar_header_and_json_body") + async def send_denial_response_400_foo_bar_header_and_json_body( + websocket: WebSocket, + ): + """Will reject ws request by just calling `websocket.close()`.""" + nonlocal test_app_dataclass + test_app_dataclass.request_dict["request"] = websocket + + denial_resp = JSONResponse({"foo": "bar"}, 400, headers={"foo": "bar"}) + await websocket.send_denial_response(denial_resp) + + @app.websocket("/receive_and_send_text_once_without_closing") async def do_nothing(websocket: WebSocket): - """Will do nothing except `websocket.accept()`.""" + """Will receive text once and send it back once, without closing ws. + + Note: user must close the ws manually, and call `websocket.state.closing.set()`. + """ nonlocal test_app_dataclass + websocket.state.closing = anyio.Event() test_app_dataclass.request_dict["request"] = websocket await websocket.accept() + recev = await websocket.receive_text() + await websocket.send_text(recev) + + await websocket.state.closing.wait() + return test_app_dataclass diff --git a/tests/app/tool.py b/tests/app/tool.py index c740478..05cead4 100644 --- a/tests/app/tool.py +++ b/tests/app/tool.py @@ -1,18 +1,30 @@ # noqa: D100 -import asyncio -import socket +from contextlib import AsyncExitStack from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Type, TypedDict, TypeVar, Union +from typing import Any, Literal, Optional, TypedDict, Union +import anyio import httpx +import sniffio import uvicorn from fastapi import FastAPI +from hypercorn import Config as HyperConfig +from hypercorn.asyncio.run import ( + worker_serve as hyper_aio_worker_serve, # pyright: ignore[reportUnknownVariableType] +) +from hypercorn.trio.run import ( + worker_serve as hyper_trio_worker_serve, # pyright: ignore[reportUnknownVariableType] +) +from hypercorn.utils import ( + repr_socket_addr, # pyright: ignore[reportUnknownVariableType] +) +from hypercorn.utils import ( + wrap_app as hyper_wrap_app, # pyright: ignore[reportUnknownVariableType] +) from starlette.requests import Request from starlette.websockets import WebSocket -from typing_extensions import Self, override - -_Decoratable_T = TypeVar("_Decoratable_T", bound=Union[Callable[..., Any], Type[Any]]) +from typing_extensions import Self, assert_never ServerRecvRequestsTypes = Union[Request, WebSocket] @@ -46,192 +58,219 @@ def get_request(self) -> ServerRecvRequestsTypes: return server_recv_request -def _no_override_uvicorn_server(_method: _Decoratable_T) -> _Decoratable_T: - """Check if the method is already in `uvicorn.Server`.""" - assert not hasattr( - uvicorn.Server, _method.__name__ - ), f"Override method of `uvicorn.Server` cls : {_method.__name__}" - return _method +class _UvicornServer(uvicorn.Server): + """subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically.""" + async def __aenter__(self) -> Self: + """Launch the server.""" + # FIXME: # 这个socket被设计为可被同一进程内的多个server共享,可能会引起潜在问题 + self._socket = self.config.bind_socket() + self._exit_stack = AsyncExitStack() -class AeixtTimeoutUndefine: - """Didn't set `contx_exit_timeout` in `aexit()`.""" + task_group = await self._exit_stack.enter_async_context( + anyio.create_task_group() + ) + task_group.start_soon( + self.serve, [self._socket], name=f"Uvicorn Server Task of {self}" + ) + return self -aexit_timeout_undefine = AeixtTimeoutUndefine() + async def __aexit__(self, *_: Any, **__: Any) -> None: + """Shutdown the server.""" + # 在 uvicorn.Server 的实现中,设置 should_exit 可以使得 server 任务结束 + assert not self.should_exit, "The server has already exited." + self.should_exit = True + await self._exit_stack.__aexit__(*_, **__) + @property + def contx_socket_url(self) -> httpx.URL: + """If server is tcp socket, return the url of server. -# HACK: 不能继承 AbstractAsyncContextManager[Self] -# 目前有问题,继承 AbstractAsyncContextManager 的话pyright也推测不出来类型 -# 只能依靠 __aenter__ 和 __aexit__ 的类型注解 -class UvicornServer(uvicorn.Server): - """subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically. + Note: The path of url is explicitly set to "/". + """ + config = self.config + if config.fd is not None or config.uds is not None: + raise RuntimeError("Only support tcp socket.") - Attributes: - contx_server_task: The task of server. - contx_socket: The socket of server. + # Implement ref: + # https://github.com/encode/uvicorn/blob/a2219eb2ed2bbda4143a0fb18c4b0578881b1ae8/uvicorn/server.py#L201-L220 + host, port = self._socket.getsockname()[:2] + return httpx.URL( + host=host, + port=port, + scheme="https" if config.is_ssl else "http", + path="/", + ) - other attributes are same as `uvicorn.Server`: - - config: The config arg that be passed in. - ... - """ - _contx_server_task: Union["asyncio.Task[None]", None] - assert not hasattr(uvicorn.Server, "_contx_server_task") +class _HypercornServer: + """An AsyncContext to launch and shutdown Hypercorn server automatically.""" - _contx_socket: Union[socket.socket, None] - assert not hasattr(uvicorn.Server, "_contx_socket") + def __init__(self, app: FastAPI, config: HyperConfig): + self.config = config + self.app = app + self.should_exit = anyio.Event() - _contx_server_started_event: Union[asyncio.Event, None] - assert not hasattr(uvicorn.Server, "_contx_server_started_event") + async def __aenter__(self) -> Self: + """Launch the server.""" + self._exit_stack = AsyncExitStack() - contx_exit_timeout: Union[int, float, None] - assert not hasattr(uvicorn.Server, "contx_exit_timeout") + self.current_async_lib = sniffio.current_async_library() - @override - def __init__( - self, config: uvicorn.Config, contx_exit_timeout: Union[int, float, None] = None - ) -> None: - """The same as `uvicorn.Server.__init__`.""" - super().__init__(config=config) - self._contx_server_task = None - self._contx_socket = None - self._contx_server_started_event = None - self.contx_exit_timeout = contx_exit_timeout - - @override - async def startup(self, sockets: Optional[List[socket.socket]] = None) -> None: - """The same as `uvicorn.Server.startup`.""" - super_return = await super().startup(sockets=sockets) - self.contx_server_started_event.set() - return super_return - - @_no_override_uvicorn_server - async def aenter(self) -> Self: - """Launch the server.""" - # 在分配资源之前,先检查是否重入 - if self.contx_server_started_event.is_set(): - raise RuntimeError("DO not launch server by __aenter__ again!") + if self.current_async_lib == "asyncio": + serve_func = ( # pyright: ignore[reportUnknownVariableType] + hyper_aio_worker_serve + ) - # FIXME: # 这个socket被设计为可被同一进程内的多个server共享,可能会引起潜在问题 - self._contx_socket = self.config.bind_socket() + # Implement ref: + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/run.py#L89-L90 + self._sockets = self.config.create_sockets() - self._contx_server_task = asyncio.create_task( - self.serve([self._contx_socket]), name=f"Uvicorn Server Task of {self}" - ) - # 在 uvicorn.Server 的实现中,Server.serve() 内部会调用 Server.startup() 完成启动 - # 被覆盖的 self.startup() 会在完成时调用 self.contx_server_started_event.set() - await self.contx_server_started_event.wait() # 等待服务器确实启动后才返回 - return self + elif self.current_async_lib == "trio": + serve_func = ( # pyright: ignore[reportUnknownVariableType] + hyper_trio_worker_serve + ) - @_no_override_uvicorn_server - async def __aenter__(self) -> Self: - """Launch the server. + # Implement ref: + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/run.py#L51-L56 + self._sockets = self.config.create_sockets() + for sock in self._sockets.secure_sockets: + sock.listen(self.config.backlog) + for sock in self._sockets.insecure_sockets: + sock.listen(self.config.backlog) - The same as `self.aenter()`. - """ - return await self.aenter() + else: + raise RuntimeError(f"Unsupported async library {self.current_async_lib!r}") + + async def serve() -> None: + # Implement ref: + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/__init__.py#L12-L46 + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/__init__.py#L14-L52 + await serve_func( + hyper_wrap_app( + self.app, # pyright: ignore[reportArgumentType] + self.config.wsgi_max_body_size, + mode=None, + ), + self.config, + shutdown_trigger=self.should_exit.wait, + sockets=self._sockets, + ) + + task_group = await self._exit_stack.enter_async_context( + anyio.create_task_group() + ) + task_group.start_soon(serve, name=f"Hypercorn Server Task of {self}") + return self - @_no_override_uvicorn_server - async def aexit( - self, - contx_exit_timeout: Union[ - int, float, None, AeixtTimeoutUndefine - ] = aexit_timeout_undefine, - ) -> None: + async def __aexit__(self, *_: Any, **__: Any) -> None: """Shutdown the server.""" - contx_server_task = self.contx_server_task - contx_socket = self.contx_socket + assert not self.should_exit.is_set(), "The server has already exited." + self.should_exit.set() + await self._exit_stack.__aexit__(*_, **__) - if isinstance(contx_exit_timeout, AeixtTimeoutUndefine): - contx_exit_timeout = self.contx_exit_timeout + @property + def contx_socket_url(self) -> httpx.URL: + """If server is tcp socket, return the url of server. - # 在 uvicorn.Server 的实现中,设置 should_exit 可以使得 server 任务结束 - assert hasattr(self, "should_exit") - self.should_exit = True + Note: The path of url is explicitly set to "/". + """ + config = self.config + sockets = self._sockets + + # Implement ref: + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/run.py#L112-L149 + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/run.py#L61-L82 + + # We only run on one socket each time, + # so we raise `RuntimeError` to avoid other unknown errors during testing. + if sockets.insecure_sockets: + if len(sockets.insecure_sockets) > 1: + raise RuntimeError("Hypercorn test: Multiple insecure_sockets found.") + socket = sockets.insecure_sockets[0] + elif sockets.secure_sockets: + if len(sockets.secure_sockets) > 1: + raise RuntimeError("Hypercorn test: secure_sockets sockets found.") + socket = sockets.secure_sockets[0] + else: + raise RuntimeError("Hypercorn test: No socket found.") - try: - await asyncio.wait_for(contx_server_task, timeout=contx_exit_timeout) - except asyncio.TimeoutError: - print(f"{contx_server_task.get_name()} timeout!") - finally: - # 其实uvicorn.Server会自动关闭socket,这里是为了保险起见 - contx_socket.close() + bind = repr_socket_addr(socket.family, socket.getsockname()) + if bind.startswith(("unix:", "fd://")): + raise RuntimeError("Only support tcp socket.") - @_no_override_uvicorn_server - async def __aexit__(self, *_: Any, **__: Any) -> None: - """Shutdown the server. + # Implement ref: + # https://docs.python.org/zh-cn/3/library/socket.html#socket-families + host, port = bind.split(":") + port = int(port) - The same as `self.aexit()`. - """ - return await self.aexit() + return httpx.URL( + host=host, + port=port, + scheme="https" if config.ssl_enabled else "http", + path="/", + ) - @property - @_no_override_uvicorn_server - def contx_server_started_event(self) -> asyncio.Event: - """The event that indicates the server has started. - When first call the property, it will instantiate a `asyncio.Event()`to - `self._contx_server_started_event`. +class AutoServer: + """An AsyncContext to launch and shutdown Hypercorn or Uvicorn server automatically.""" - Warn: This is a internal implementation detail, do not change the event manually. - - please call the property in `self.aenter()` or `self.startup()` **first**. - - **Never** call it outside of an async event loop first: - https://stackoverflow.com/questions/53724665/using-queues-results-in-asyncio-exception-got-future-future-pending-attached - """ - if self._contx_server_started_event is None: - self._contx_server_started_event = asyncio.Event() + server_type: Literal["uvicorn", "hypercorn"] - return self._contx_server_started_event + def __init__( + self, + app: FastAPI, + host: str, + port: int, + server_type: Optional[Literal["uvicorn", "hypercorn"]] = None, + ): + """Only support ipv4 address. - @property - @_no_override_uvicorn_server - def contx_socket(self) -> socket.socket: - """The socket of server. + If use uvicorn, it only support asyncio backend. - Note: must call `self.__aenter__()` first. + If `host` == 0, then use random port. """ - if self._contx_socket is None: - raise RuntimeError("Please call `self.__aenter__()` first.") + self.app = app + self.host = host + self.port = port + self._server_type: Optional[Literal["uvicorn", "hypercorn"]] = server_type + + async def __aenter__(self) -> Self: + """Launch the server.""" + if self._server_type is None: + if sniffio.current_async_library() == "asyncio": + self.server_type = "uvicorn" + else: + self.server_type = "hypercorn" else: - return self._contx_socket + self.server_type = self._server_type - @property - @_no_override_uvicorn_server - def contx_server_task(self) -> "asyncio.Task[None]": - """The task of server. + if self.server_type == "hypercorn": + config = HyperConfig() + config.bind = f"{self.host}:{self.port}" - Note: must call `self.__aenter__()` first. - """ - if self._contx_server_task is None: - raise RuntimeError("Please call `self.__aenter__()` first.") + self.config = config + self.server = _HypercornServer(self.app, config) + elif self.server_type == "uvicorn": + self.config = uvicorn.Config(self.app, host=self.host, port=self.port) + self.server = _UvicornServer(self.config) else: - return self._contx_server_task + assert_never(self.server_type) - @property - @_no_override_uvicorn_server - def contx_socket_getname(self) -> Any: - """Utils for calling self.contx_socket.getsockname(). + self._exit_stack = AsyncExitStack() + await self._exit_stack.enter_async_context(self.server) + await anyio.sleep(0.5) # XXX, HACK: wait for server to start + return self - Return: - refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families - """ - return self.contx_socket.getsockname() + async def __aexit__(self, *_: Any, **__: Any) -> None: + """Shutdown the server.""" + await self._exit_stack.__aexit__(*_, **__) @property - @_no_override_uvicorn_server def contx_socket_url(self) -> httpx.URL: """If server is tcp socket, return the url of server. Note: The path of url is explicitly set to "/". """ - config = self.config - if config.fd is not None or config.uds is not None: - raise RuntimeError("Only support tcp socket.") - host, port = self.contx_socket_getname[:2] - return httpx.URL( - host=host, - port=port, - scheme="https" if config.is_ssl else "http", - path="/", - ) + return self.server.contx_socket_url diff --git a/tests/conftest.py b/tests/conftest.py index 0527101..e60ec07 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,13 +15,13 @@ Callable, Coroutine, Literal, + Optional, Protocol, - Union, ) import pytest -import uvicorn from asgi_lifespan import LifespanManager +from fastapi import FastAPI from fastapi_proxy_lib.fastapi.app import ( forward_http_app, reverse_http_app, @@ -31,7 +31,7 @@ from .app.echo_http_app import get_app as get_http_test_app from .app.echo_ws_app import get_app as get_ws_test_app -from .app.tool import AppDataclass4Test, UvicornServer +from .app.tool import AppDataclass4Test, AutoServer # ASGI types. # Copied from: https://github.com/florimondmanca/asgi-lifespan/blob/fbb0f440337314be97acaae1a3c0c7a2ec8298dd/src/asgi_lifespan/_types.py @@ -62,17 +62,28 @@ class LifeAppDataclass4Test(AppDataclass4Test): """The lifespan of app will be managed automatically by pytest.""" -class UvicornServerFixture(Protocol): # noqa: D101 +class AutoServerFixture(Protocol): # noqa: D101 def __call__( # noqa: D102 - self, config: uvicorn.Config, contx_exit_timeout: Union[int, float, None] = None - ) -> Coroutine[None, None, UvicornServer]: ... + self, + app: FastAPI, + host: str, + port: int, + server_type: Optional[Literal["uvicorn", "hypercorn"]] = None, + ) -> Coroutine[None, None, AutoServer]: ... # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on -@pytest.fixture() -def anyio_backend() -> Literal["asyncio"]: +@pytest.fixture( + params=[ + pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"), + pytest.param( + ("trio", {"restrict_keyboard_interrupt_to_checkpoints": True}), id="trio" + ), + ], +) +def anyio_backend(request: pytest.FixtureRequest): """Specify the async backend for `pytest.mark.anyio`.""" - return "asyncio" + return request.param @pytest.fixture() @@ -192,19 +203,22 @@ def reverse_ws_app_fct( @pytest.fixture() -async def uvicorn_server_fixture() -> AsyncIterator[UvicornServerFixture]: - """Fixture for UvicornServer. +async def auto_server_fixture() -> AsyncIterator[AutoServerFixture]: + """Fixture for AutoServer. Will launch and shutdown automatically. """ async with AsyncExitStack() as exit_stack: - async def uvicorn_server_fct( - config: uvicorn.Config, contx_exit_timeout: Union[int, float, None] = None - ) -> UvicornServer: - uvicorn_server = await exit_stack.enter_async_context( - UvicornServer(config=config, contx_exit_timeout=contx_exit_timeout) + async def auto_server_fct( + app: FastAPI, + host: str, + port: int, + server_type: Optional[Literal["uvicorn", "hypercorn"]] = None, + ) -> AutoServer: + auto_server = await exit_stack.enter_async_context( + AutoServer(app=app, host=host, port=port, server_type=server_type) ) - return uvicorn_server + return auto_server - yield uvicorn_server_fct + yield auto_server_fct diff --git a/tests/test_core_lib.py b/tests/test_core_lib.py index da52e58..b214b2e 100644 --- a/tests/test_core_lib.py +++ b/tests/test_core_lib.py @@ -70,7 +70,10 @@ async def _() -> JSONResponse: # } # } - client = httpx.AsyncClient(app=app, base_url="http://www.example.com") + client = httpx.AsyncClient( + transport=httpx.ASGITransport(app), # pyright: ignore[reportArgumentType] + base_url="http://www.example.com", + ) resp = await client.get("http://www.example.com/exception") assert resp.status_code == 0 assert resp.json()["detail"] == test_err_msg diff --git a/tests/test_docs_examples.py b/tests/test_docs_examples.py index 12bea06..d7a9a3a 100644 --- a/tests/test_docs_examples.py +++ b/tests/test_docs_examples.py @@ -23,8 +23,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: app = FastAPI(lifespan=close_proxy_event) @app.get("/{path:path}") - async def _(request: Request, path: str = ""): - return await proxy.proxy(request=request, path=path) + async def _(request: Request): + return await proxy.proxy(request=request) # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `http://127.0.0.1:8000/http://www.example.com` @@ -52,8 +52,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: # (1)! app = FastAPI(lifespan=close_proxy_event) @app.get("/{path:path}") # (2)! - async def _(request: Request, path: str = ""): - return await proxy.proxy(request=request, path=path) # (3)! + async def _(request: Request): + return await proxy.proxy(request=request) # (3)! # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `http://127.0.0.1:8000/` @@ -62,11 +62,7 @@ async def _(request: Request, path: str = ""): """ 1. lifespan please refer to [starlette/lifespan](https://www.starlette.io/lifespan/) 2. `{path:path}` is the key.
It allows the app to accept all path parameters.
- visit for more info. - 3. !!! info - In fact, you only need to pass the `request: Request` argument.
- `fastapi_proxy_lib` can automatically get the `path` from `request`.
- Explicitly pointing it out here is just to remind you not to forget to specify `{path:path}`. """ + visit for more info. """ def test_reverse_ws_proxy() -> None: @@ -90,8 +86,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: app = FastAPI(lifespan=close_proxy_event) @app.websocket("/{path:path}") - async def _(websocket: WebSocket, path: str = ""): - return await proxy.proxy(websocket=websocket, path=path) + async def _(websocket: WebSocket): + return await proxy.proxy(websocket=websocket) # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `ws://127.0.0.1:8000/` diff --git a/tests/test_http.py b/tests/test_http.py index e2c649c..71b56ac 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -32,7 +32,10 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri ) -> Tool4TestFixture: """目标服务器请参考`tests.app.echo_http_app.get_app`.""" client_for_conn_to_target_server = httpx.AsyncClient( - app=echo_http_test_model.app, base_url=DEFAULT_TARGET_SERVER_BASE_URL + transport=httpx.ASGITransport( + echo_http_test_model.app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_TARGET_SERVER_BASE_URL, ) reverse_http_app = await reverse_http_app_fct( @@ -41,7 +44,10 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=reverse_http_app, base_url=DEFAULT_PROXY_SERVER_BASE_URL + transport=httpx.ASGITransport( + reverse_http_app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_PROXY_SERVER_BASE_URL, ) get_request = echo_http_test_model.get_request @@ -198,7 +204,10 @@ async def test_bad_url_request( ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=reverse_http_app, base_url=DEFAULT_PROXY_SERVER_BASE_URL + transport=httpx.ASGITransport( + reverse_http_app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_PROXY_SERVER_BASE_URL, ) r = await client_for_conn_to_proxy_server.get(DEFAULT_PROXY_SERVER_BASE_URL) @@ -233,9 +242,9 @@ async def test_cookie_leakage( assert not client_for_conn_to_proxy_server.cookies # check if cookie is not leaked + client_for_conn_to_proxy_server.cookies.set("a", "b") r = await client_for_conn_to_proxy_server.get( - proxy_server_base_url + "get/cookies", - cookies={"a": "b"}, + proxy_server_base_url + "get/cookies" ) assert "foo" not in r.json() # not leaked assert r.json()["a"] == "b" # send cookies normally @@ -252,7 +261,10 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri ) -> Tool4TestFixture: """目标服务器请参考`tests.app.echo_http_app.get_app`.""" client_for_conn_to_target_server = httpx.AsyncClient( - app=echo_http_test_model.app, base_url=DEFAULT_TARGET_SERVER_BASE_URL + transport=httpx.ASGITransport( + echo_http_test_model.app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_TARGET_SERVER_BASE_URL, ) forward_http_app = await forward_http_app_fct( @@ -260,7 +272,10 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=forward_http_app, base_url=DEFAULT_PROXY_SERVER_BASE_URL + transport=httpx.ASGITransport( + forward_http_app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_PROXY_SERVER_BASE_URL, ) get_request = echo_http_test_model.get_request @@ -310,7 +325,10 @@ async def test_bad_url_request( ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=forward_http_app, base_url=DEFAULT_PROXY_SERVER_BASE_URL + transport=httpx.ASGITransport( + forward_http_app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_PROXY_SERVER_BASE_URL, ) # 错误的无法发出请求的URL @@ -356,7 +374,10 @@ async def connect_error_mock_handler( ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=forward_http_app, base_url=DEFAULT_PROXY_SERVER_BASE_URL + transport=httpx.ASGITransport( + forward_http_app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_PROXY_SERVER_BASE_URL, ) r = await client_for_conn_to_proxy_server.get( @@ -385,7 +406,9 @@ async def test_denial_http2( ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=forward_http_app, + transport=httpx.ASGITransport( + forward_http_app + ), # pyright: ignore[reportArgumentType] base_url=proxy_server_base_url, http2=True, http1=False, diff --git a/tests/test_ws.py b/tests/test_ws.py index 2119719..c76b912 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -1,157 +1,150 @@ # noqa: D100 -import asyncio from contextlib import AsyncExitStack +from dataclasses import dataclass from multiprocessing import Process, Queue -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict +import anyio import httpx import httpx_ws import pytest -import uvicorn from fastapi_proxy_lib.fastapi.app import reverse_ws_app as get_reverse_ws_app from httpx_ws import aconnect_ws from starlette import websockets as starlette_websockets_module from typing_extensions import override from .app.echo_ws_app import get_app as get_ws_test_app -from .app.tool import UvicornServer -from .conftest import UvicornServerFixture +from .app.tool import AutoServer +from .conftest import AutoServerFixture from .tool import ( AbstractTestProxy, Tool4TestFixture, ) DEFAULT_HOST = "127.0.0.1" -DEFAULT_PORT = 0 -DEFAULT_CONTX_EXIT_TIMEOUT = 5 +DEFAULT_PORT = 0 # random port -# WS_BACKENDS_NEED_BE_TESTED = ("websockets", "wsproto") -# # FIXME: wsproto 有问题,暂时不测试 -# # ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接。 -# # https://github.com/encode/uvicorn/discussions/2105 -WS_BACKENDS_NEED_BE_TESTED = ("websockets",) - -# https://www.python-httpx.org/advanced/#http-proxying +# https://www.python-httpx.org/advanced/proxies/ +# NOTE: Foce to connect directly, avoid using system proxies NO_PROXIES: Dict[Any, Any] = {"all://": None} -def _subprocess_run_echo_ws_uvicorn_server(queue: "Queue[str]", **kwargs: Any): +@dataclass +class Tool4ServerTestFixture(Tool4TestFixture): # noqa: D101 + target_server: AutoServer + proxy_server: AutoServer + + +def _subprocess_run_echo_ws_server(queue: "Queue[str]"): """Run echo ws app in subprocess. Args: queue: The queue for subprocess to put the url of echo ws app. After the server is started, the url will be put into the queue. - **kwargs: The kwargs for `uvicorn.Config` """ - default_kwargs = { - "app": get_ws_test_app().app, - "port": DEFAULT_PORT, - "host": DEFAULT_HOST, - } - default_kwargs.update(kwargs) - target_ws_server = UvicornServer( - uvicorn.Config(**default_kwargs), # pyright: ignore[reportArgumentType] + target_ws_server = AutoServer( + app=get_ws_test_app().app, + host=DEFAULT_HOST, + port=DEFAULT_PORT, ) async def run(): - await target_ws_server.aenter() - url = str(target_ws_server.contx_socket_url) - queue.put(url) - queue.close() - while True: # run forever - await asyncio.sleep(0.1) + async with target_ws_server: + url = str(target_ws_server.contx_socket_url) + queue.put(url) + queue.close() + while True: # run forever + await anyio.sleep(0.1) - asyncio.run(run()) + # It's not proxy app server for which we test, so it's ok to not use trio backend + anyio.run(run) def _subprocess_run_httpx_ws( queue: "Queue[str]", - kwargs_async_client: Optional[Dict[str, Any]] = None, - kwargs_aconnect_ws: Optional[Dict[str, Any]] = None, + aconnect_ws_url: str, ): """Run aconnect_ws in subprocess. Args: queue: The queue for subprocess to put something for flag of ws connection established. - kwargs_async_client: The kwargs for `httpx.AsyncClient` - kwargs_aconnect_ws: The kwargs for `httpx_ws.aconnect_ws` + aconnect_ws_url: The websocket url for aconnect_ws. + will add "receive_and_send_text_once_without_closing" to the url. """ - kwargs_async_client = kwargs_async_client or {} - kwargs_aconnect_ws = kwargs_aconnect_ws or {} - - kwargs_async_client.pop("proxies", None) - kwargs_aconnect_ws.pop("client", None) async def run(): _exit_stack = AsyncExitStack() - _temp_client = httpx.AsyncClient(proxies=NO_PROXIES, **kwargs_async_client) - _ = await _exit_stack.enter_async_context( + _temp_client = httpx.AsyncClient(mounts=NO_PROXIES) + ws = await _exit_stack.enter_async_context( aconnect_ws( client=_temp_client, - **kwargs_aconnect_ws, + url=aconnect_ws_url + "receive_and_send_text_once_without_closing", ) ) + # make sure ws is connected + msg = "foo" + await ws.send_text(msg) + await ws.receive_text() + # use queue to notify the connection established queue.put("done") queue.close() while True: # run forever - await asyncio.sleep(0.1) + await anyio.sleep(0.1) - asyncio.run(run()) + # It's not proxy app server for which we test, so it's ok to not use trio backend + anyio.run(run) class TestReverseWsProxy(AbstractTestProxy): """For testing reverse websocket proxy.""" @override - @pytest.fixture(params=WS_BACKENDS_NEED_BE_TESTED) + @pytest.fixture() async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverride] self, - uvicorn_server_fixture: UvicornServerFixture, - request: pytest.FixtureRequest, - ) -> Tool4TestFixture: + auto_server_fixture: AutoServerFixture, + ) -> Tool4ServerTestFixture: """目标服务器请参考`tests.app.echo_ws_app.get_app`.""" echo_ws_test_model = get_ws_test_app() echo_ws_app = echo_ws_test_model.app echo_ws_get_request = echo_ws_test_model.get_request - target_ws_server = await uvicorn_server_fixture( - uvicorn.Config( - echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param - ), - contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT, + target_ws_server = await auto_server_fixture( + app=echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST ) target_server_base_url = str(target_ws_server.contx_socket_url) - client_for_conn_to_target_server = httpx.AsyncClient(proxies=NO_PROXIES) + client_for_conn_to_target_server = httpx.AsyncClient(mounts=NO_PROXIES) reverse_ws_app = get_reverse_ws_app( client=client_for_conn_to_target_server, base_url=target_server_base_url ) - proxy_ws_server = await uvicorn_server_fixture( - uvicorn.Config( - reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param - ), - contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT, + proxy_ws_server = await auto_server_fixture( + app=reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST ) proxy_server_base_url = str(proxy_ws_server.contx_socket_url) - client_for_conn_to_proxy_server = httpx.AsyncClient(proxies=NO_PROXIES) + client_for_conn_to_proxy_server = httpx.AsyncClient(mounts=NO_PROXIES) - return Tool4TestFixture( + return Tool4ServerTestFixture( client_for_conn_to_target_server=client_for_conn_to_target_server, client_for_conn_to_proxy_server=client_for_conn_to_proxy_server, get_request=echo_ws_get_request, target_server_base_url=target_server_base_url, proxy_server_base_url=proxy_server_base_url, + target_server=target_ws_server, + proxy_server=proxy_ws_server, ) @pytest.mark.anyio() - async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: + async def test_ws_proxy( # noqa: PLR0915 + self, tool_4_test_fixture: Tool4ServerTestFixture + ) -> None: """测试websocket代理.""" proxy_server_base_url = tool_4_test_fixture.proxy_server_base_url client_for_conn_to_proxy_server = ( @@ -159,6 +152,9 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: ) get_request = tool_4_test_fixture.get_request + target_server = tool_4_test_fixture.target_server + proxy_server = tool_4_test_fixture.proxy_server + ########## 测试数据的正常转发 ########## async with aconnect_ws( @@ -176,21 +172,58 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: ########## 测试子协议 ########## async with aconnect_ws( - proxy_server_base_url + "accept_foo_subprotocol", + proxy_server_base_url + "accept_foo_subprotocol_and_foo_bar_header", client_for_conn_to_proxy_server, subprotocols=["foo", "bar"], ) as ws: assert ws.subprotocol == "foo" + assert ws.response is not None + assert ws.response.headers["foo"] == "bar" - ########## 关闭代码 ########## + ########## 客户端发送关闭代码 ########## + code = 1003 + reason = "foo" async with aconnect_ws( - proxy_server_base_url + "just_close_with_1001", + proxy_server_base_url + "receive_and_send_text_once_without_closing", + client_for_conn_to_proxy_server, + ) as ws: + await ws.send_text("foo") + await ws.receive_text() + await ws.close(code=code, reason=reason) + + target_starlette_ws = get_request() + assert isinstance(target_starlette_ws, starlette_websockets_module.WebSocket) + with pytest.raises(starlette_websockets_module.WebSocketDisconnect) as exce: + await target_starlette_ws.receive_text() + + closing_event = target_starlette_ws.state.closing + assert isinstance(closing_event, anyio.Event) + closing_event.set() + + # XXX, HACK, TODO: + # hypercorn can't receive correctly close code, it always receive 1006 + # https://github.com/pgjones/hypercorn/issues/127 + # so we only test close code for uvicorn + if ( + target_server.server_type == "uvicorn" + and proxy_server.server_type == "uvicorn" + ): + assert exce.value.code == code + # XXX, HACK, TODO: + # reaseon are wrong, httpx-ws can't send close reason correctly + # assert exce.value.reason == reason + + ########## 服务端发送关闭代码 ########## + + async with aconnect_ws( + proxy_server_base_url + "just_close_with_1002_and_foo", client_for_conn_to_proxy_server, ) as ws: with pytest.raises(httpx_ws.WebSocketDisconnect) as exce: await ws.receive_text() - assert exce.value.code == 1001 + assert exce.value.code == 1002 + assert exce.value.reason == "foo" ########## 协议升级失败或者连接失败 ########## @@ -200,33 +233,42 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: client_for_conn_to_proxy_server, ) as ws: pass - # uvicorn 服务器在未调用`websocket.accept()`之前调用了`websocket.close()`,会发生403 + # Starlette 在未调用`websocket.accept()`之前调用了`websocket.close()`,会发生403 assert exce.value.response.status_code == 403 + ########## test denial response ########## + + with pytest.raises(httpx_ws.WebSocketUpgradeError) as exce: + async with aconnect_ws( + proxy_server_base_url + + "send_denial_response_400_foo_bar_header_and_json_body", + client_for_conn_to_proxy_server, + ) as ws: + pass + # Starlette 在未调用`websocket.accept()`之前调用了`websocket.close()`,会发生403 + assert exce.value.response.status_code == 400 + assert exce.value.response.headers["foo"] == "bar" + # XXX, HACK, TODO: Unable to read the content of WebSocketUpgradeError.response + # See: https://github.com/frankie567/httpx-ws/discussions/69 + # assert exce.value.response.json() == {"foo": "bar"} + ########## 客户端突然关闭时,服务器应该收到1011 ########## # NOTE: 这个测试不放在 `test_target_server_shutdown_abnormally` 来做 # 是因为这里已经有现成的target server,放在这里测试可以节省启动服务器时间 aconnect_ws_subprocess_queue: "Queue[str]" = Queue() - - kwargs_async_client = {"proxies": NO_PROXIES} - kwargs_aconnect_ws = {"url": proxy_server_base_url + "do_nothing"} - kwargs = { - "kwargs_async_client": kwargs_async_client, - "kwargs_aconnect_ws": kwargs_aconnect_ws, - } + aconnect_ws_url = proxy_server_base_url aconnect_ws_subprocess = Process( target=_subprocess_run_httpx_ws, - args=(aconnect_ws_subprocess_queue,), - kwargs=kwargs, + args=(aconnect_ws_subprocess_queue, aconnect_ws_url), ) aconnect_ws_subprocess.start() # 避免从队列中get导致的异步阻塞 while aconnect_ws_subprocess_queue.empty(): - await asyncio.sleep(0.1) + await anyio.sleep(0.1) _ = aconnect_ws_subprocess_queue.get() # 获取到了即代表连接建立成功 # force shutdown client @@ -240,18 +282,20 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: with pytest.raises(starlette_websockets_module.WebSocketDisconnect) as exce: await target_starlette_ws.receive_text() # receive_bytes() 也可以 + closing_event = target_starlette_ws.state.closing + assert isinstance(closing_event, anyio.Event) + closing_event.set() + # assert exce.value.code == 1011 # HACK, FIXME: 无法测试错误代码,似乎无法正常传递,且不同后端也不同 # FAILED test_ws_proxy[websockets] - assert 1005 == 1011 # FAILED test_ws_proxy[wsproto] - assert == 1011 + # NOTE: the close code for abnormal close is undefined behavior, so we won't test this # FIXME: 调查为什么收到关闭代码需要40s @pytest.mark.timeout(60) @pytest.mark.anyio() - @pytest.mark.parametrize("ws_backend", WS_BACKENDS_NEED_BE_TESTED) - async def test_target_server_shutdown_abnormally( - self, ws_backend: Literal["websockets", "wsproto"] - ) -> None: + async def test_target_server_shutdown_abnormally(self) -> None: """测试因为目标服务器突然断连导致的,ws桥接异常关闭. 需要在 60s 内向客户端发送 1011 关闭代码. @@ -259,37 +303,43 @@ async def test_target_server_shutdown_abnormally( subprocess_queue: "Queue[str]" = Queue() target_ws_server_subprocess = Process( - target=_subprocess_run_echo_ws_uvicorn_server, + target=_subprocess_run_echo_ws_server, args=(subprocess_queue,), - kwargs={"port": DEFAULT_PORT, "host": DEFAULT_HOST, "ws": ws_backend}, ) target_ws_server_subprocess.start() # 避免从队列中get导致的异步阻塞 while subprocess_queue.empty(): - await asyncio.sleep(0.1) + await anyio.sleep(0.1) target_server_base_url = subprocess_queue.get() - client_for_conn_to_target_server = httpx.AsyncClient(proxies=NO_PROXIES) + client_for_conn_to_target_server = httpx.AsyncClient(mounts=NO_PROXIES) reverse_ws_app = get_reverse_ws_app( client=client_for_conn_to_target_server, base_url=target_server_base_url ) - async with UvicornServer( - uvicorn.Config( - reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=ws_backend - ) + async with AutoServer( + app=reverse_ws_app, + port=DEFAULT_PORT, + host=DEFAULT_HOST, ) as proxy_ws_server: proxy_server_base_url = str(proxy_ws_server.contx_socket_url) async with aconnect_ws( - proxy_server_base_url + "do_nothing", - httpx.AsyncClient(proxies=NO_PROXIES), + proxy_server_base_url + "echo_text", + httpx.AsyncClient(mounts=NO_PROXIES), ) as ws0, aconnect_ws( - proxy_server_base_url + "do_nothing", - httpx.AsyncClient(proxies=NO_PROXIES), + proxy_server_base_url + "echo_text", + httpx.AsyncClient(mounts=NO_PROXIES), ) as ws1: + # make sure ws is connected + msg = "foo" + await ws0.send_text(msg) + assert msg == await ws0.receive_text() + await ws1.send_text(msg) + assert msg == await ws1.receive_text() + # force shutdown target server target_ws_server_subprocess.terminate() target_ws_server_subprocess.kill() @@ -300,16 +350,16 @@ async def test_target_server_shutdown_abnormally( await ws0.receive() assert exce.value.code == 1011 - loop = asyncio.get_running_loop() - - seconde_ws_recv_start = loop.time() + seconde_ws_recv_start = anyio.current_time() with pytest.raises(httpx_ws.WebSocketDisconnect) as exce: await ws1.receive() assert exce.value.code == 1011 - seconde_ws_recv_end = loop.time() + seconde_ws_recv_end = anyio.current_time() # HACK: 由于收到关闭代码需要40s,目前无法确定是什么原因, # 所以目前会同时测试两个客户端的连接, # 只要第二个客户端不是在之前40s基础上又重复40s,就暂时没问题, # 因为这模拟了多个客户端进行连接的情况。 assert (seconde_ws_recv_end - seconde_ws_recv_start) < 2 + + # NOTE: the close code for abnormal close is undefined behavior, so we won't test this