Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support anyio, sending denial response, handshake headers #34

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion docs/Usage/FastAPI-Helper.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ app = reverse_http_app(client=client, base_url=base_url)
```

1. You can pass `httpx.AsyncClient` instance:
- if you want to customize the arguments, e.g. `httpx.AsyncClient(proxies={})`
- if you want to customize the arguments, e.g. `httpx.AsyncClient(http2=True)`
- if you want to reuse the connection pool of `httpx.AsyncClient`
---
Or you can pass `None`(The default value), then `fastapi-proxy-lib` will create a new `httpx.AsyncClient` instance for you.
Expand Down
13 changes: 8 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ dynamic = ["version"]

dependencies = [
"httpx",
"httpx-ws >= 0.4.2",
"starlette",
"httpx-ws >= 0.6.0",
"starlette >= 0.37.2",
"typing_extensions >=4.5.0",
"anyio >= 4",
"exceptiongroup",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -96,10 +98,11 @@ dependencies = [
"pytest == 7.*",
"pytest-cov == 4.*",
"uvicorn[standard] < 1.0.0", # TODO: Once it releases version 1.0.0, we will remove this restriction.
"hypercorn[trio] == 0.16.*",
"httpx[http2]", # we don't set version here, instead set it in `[project].dependencies`.
"anyio", # we don't set version here, because fastapi has a dependency on it
"asgi-lifespan==2.*",
"pytest-timeout==2.*",
"asgi-lifespan == 2.*",
"pytest-timeout == 2.*",
"sniffio == 1.3.*",
]

[tool.hatch.envs.default.scripts]
Expand Down
42 changes: 2 additions & 40 deletions src/fastapi_proxy_lib/core/_tool.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""The utils tools for both http proxy and websocket proxy."""

import ipaddress
import logging
import warnings
from functools import lru_cache
from textwrap import dedent
from typing import (
Any,
Iterable,
Mapping,
Optional,
Protocol,
Expand All @@ -17,7 +15,6 @@
)

import httpx
from starlette import status
from starlette.background import BackgroundTask as BackgroundTask_t
from starlette.datastructures import (
Headers as StarletteHeaders,
Expand All @@ -26,13 +23,11 @@
MutableHeaders as StarletteMutableHeaders,
)
from starlette.responses import JSONResponse
from starlette.types import Scope
from typing_extensions import deprecated, overload

__all__ = (
"check_base_url",
"return_err_msg_response",
"check_http_version",
"BaseURLError",
"ErrMsg",
"ErrRseponseJson",
Expand Down Expand Up @@ -129,10 +124,6 @@ class _RejectedProxyRequestError(RuntimeError):
"""Should be raised when reject proxy request."""


class _UnsupportedHttpVersionError(RuntimeError):
"""Unsupported http version."""


#################### Tools ####################


Expand Down Expand Up @@ -309,8 +300,8 @@ def return_err_msg_response(
err_response_json = ErrRseponseJson(detail=detail)

# TODO: 请注意,logging是同步函数,每次会阻塞1ms左右,这可能会导致性能问题
# 特别是对于写入文件的log,最好把它放到 asyncio.to_thread 里执行
# https://docs.python.org/zh-cn/3/library/asyncio-task.html#coroutine
# 特别是对于写入文件的log,最好把它放到 `anyio.to_thread.run_sync()` 里执行
# https://anyio.readthedocs.io/en/stable/threads.html#running-a-function-in-a-worker-thread

if logger is not None:
# 只要传入了logger,就一定记录日志
Expand All @@ -337,35 +328,6 @@ def return_err_msg_response(
)


def check_http_version(
scope: Scope, supported_versions: Iterable[str]
) -> Union[JSONResponse, None]:
"""Check whether the http version of scope is in supported_versions.

Args:
scope: asgi scope dict.
supported_versions: The supported http versions.

Returns:
If the http version of scope is not in supported_versions, return a JSONResponse with status_code=505,
else return None.
"""
# https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
# https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope
http_version: str = scope.get("http_version", "")
# 如果明确指定了http版本(即不是""),但不在支持的版本内,则返回505
if http_version not in supported_versions and http_version != "":
error = _UnsupportedHttpVersionError(
f"The request http version is {http_version}, but we only support {supported_versions}."
)
# TODO: 或许可以logging记录下 scope.get("client") 的值
return return_err_msg_response(
error,
status_code=status.HTTP_505_HTTP_VERSION_NOT_SUPPORTED,
logger=logging.info,
)


def default_proxy_filter(url: httpx.URL) -> Union[None, str]:
"""Filter by host.

Expand Down
41 changes: 8 additions & 33 deletions src/fastapi_proxy_lib/core/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
_RejectedProxyRequestError, # pyright: ignore [reportPrivateUsage] # 允许使用本项目内部的私有成员
change_necessary_client_header_for_httpx,
check_base_url,
check_http_version,
return_err_msg_response,
warn_for_none_filter,
)
Expand Down Expand Up @@ -81,10 +80,6 @@ class _ReverseProxyServerError(RuntimeError):
_NON_REQUEST_BODY_METHODS = ("GET", "HEAD", "OPTIONS", "TRACE")
"""The http methods that should not contain request body."""

# https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
SUPPORTED_HTTP_VERSIONS = ("1.0", "1.1")
"""The http versions that we supported now. It depends on `httpx`."""

# https://www.python-httpx.org/exceptions/
_400_ERROR_NEED_TO_BE_CATCHED_IN_FORWARD_PROXY = (
httpx.InvalidURL, # 解析url时出错
Expand Down Expand Up @@ -227,8 +222,6 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv
) -> StarletteResponse:
"""Change request headers and send request to target url.

- The http version of request must be in [`SUPPORTED_HTTP_VERSIONS`][fastapi_proxy_lib.core.http.SUPPORTED_HTTP_VERSIONS].

Args:
request: the original client request.
target_url: target url that request will be sent to.
Expand All @@ -239,10 +232,6 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv
client = self.client
follow_redirects = self.follow_redirects

check_result = check_http_version(request.scope, SUPPORTED_HTTP_VERSIONS)
if check_result is not None:
return check_result

# 将请求头中的host字段改为目标url的host
# 同时强制移除"keep-alive"字段和添加"keep-alive"值到"connection"字段中保持连接
require_close, proxy_header = _change_client_header(
Expand Down Expand Up @@ -338,8 +327,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: # (1)!
app = FastAPI(lifespan=close_proxy_event)

@app.get("/{path:path}") # (2)!
async def _(request: Request, path: str = ""):
return await proxy.proxy(request=request, path=path) # (3)!
async def _(request: Request):
return await proxy.proxy(request=request) # (3)!

# Then run shell: `uvicorn <your.py>:app --host http://127.0.0.1:8000 --port 8000`
# visit the app: `http://127.0.0.1:8000/`
Expand All @@ -350,10 +339,6 @@ async def _(request: Request, path: str = ""):
2. `{path:path}` is the key.<br>
It allows the app to accept all path parameters.<br>
visit <https://www.starlette.io/routing/#path-parameters> for more info.
3. !!! info
In fact, you only need to pass the `request: Request` argument.<br>
`fastapi_proxy_lib` can automatically get the `path` from `request`.<br>
Explicitly pointing it out here is just to remind you not to forget to specify `{path:path}`.
'''

client: httpx.AsyncClient
Expand Down Expand Up @@ -387,25 +372,20 @@ def __init__(

@override
async def proxy( # pyright: ignore [reportIncompatibleMethodOverride]
self, *, request: StarletteRequest, path: Optional[str] = None
self, *, request: StarletteRequest
) -> StarletteResponse:
"""Send request to target server.

Args:
request: `starlette.requests.Request`
path: The path params of request, which means the path params of base url.<br>
If None, will get it from `request.path_params`.<br>
**Usually, you don't need to pass this argument**.

Returns:
The response from target server.
"""
base_url = self.base_url

# 只取第一个路径参数。注意,我们允许没有路径参数,这代表直接请求
path_param: str = (
path if path is not None else next(iter(request.path_params.values()), "")
)
path_param: str = next(iter(request.path_params.values()), "")

# 将路径参数拼接到目标url上
# e.g: "https://www.example.com/p0/" + "p1"
Expand Down Expand Up @@ -473,8 +453,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]:
app = FastAPI(lifespan=close_proxy_event)

@app.get("/{path:path}")
async def _(request: Request, path: str = ""):
return await proxy.proxy(request=request, path=path)
async def _(request: Request):
return await proxy.proxy(request=request)

# Then run shell: `uvicorn <your.py>:app --host http://127.0.0.1:8000 --port 8000`
# visit the app: `http://127.0.0.1:8000/http://www.example.com`
Expand Down Expand Up @@ -513,25 +493,20 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride]
self,
*,
request: StarletteRequest,
path: Optional[str] = None,
) -> StarletteResponse:
"""Send request to target server.

Args:
request: `starlette.requests.Request`
path: The path params of request, which means the full url of target server.<br>
If None, will get it from `request.path_params`.<br>
**Usually, you don't need to pass this argument**.

Returns:
The response from target server.
"""
proxy_filter = self.proxy_filter

# 只取第一个路径参数
path_param: str = (
next(iter(request.path_params.values()), "") if path is None else path
)
path_param: str = next(iter(request.path_params.values()), "")

# 如果没有路径参数,即在正向代理中未指定目标url,则返回400
if path_param == "":
error = _BadTargetUrlError("Must provide target url.")
Expand Down
Loading
Loading