-
Notifications
You must be signed in to change notification settings - Fork 105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open sockets outside of connection objects #275
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,51 @@ | ||
from ssl import SSLContext | ||
from typing import Optional, Tuple, cast | ||
from typing import Optional, Tuple | ||
|
||
from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend | ||
from .._exceptions import ConnectError, ConnectTimeout | ||
from .._backends.auto import AsyncBackend, AsyncSocketStream, AutoBackend | ||
from .._types import URL, Headers, Origin, TimeoutDict | ||
from .._utils import exponential_backoff, get_logger, url_to_origin | ||
from .base import ( | ||
AsyncByteStream, | ||
AsyncHTTPTransport, | ||
ConnectionState, | ||
NewConnectionRequired, | ||
) | ||
from .._utils import get_logger, url_to_origin | ||
from .base import AsyncByteStream, AsyncHTTPTransport, ConnectionState | ||
from .http import AsyncBaseHTTPConnection | ||
from .http11 import AsyncHTTP11Connection | ||
|
||
logger = get_logger(__name__) | ||
|
||
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. | ||
|
||
|
||
class AsyncHTTPConnection(AsyncHTTPTransport): | ||
def __init__( | ||
self, | ||
origin: Origin, | ||
http2: bool = False, | ||
uds: str = None, | ||
socket: AsyncSocketStream, | ||
ssl_context: SSLContext = None, | ||
socket: AsyncSocketStream = None, | ||
local_address: str = None, | ||
retries: int = 0, | ||
backend: AsyncBackend = None, | ||
): | ||
self.origin = origin | ||
self.http2 = http2 | ||
self.uds = uds | ||
self.ssl_context = SSLContext() if ssl_context is None else ssl_context | ||
self.socket = socket | ||
self.local_address = local_address | ||
self.retries = retries | ||
|
||
if self.http2: | ||
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"]) | ||
self.ssl_context = SSLContext() if ssl_context is None else ssl_context | ||
|
||
self.connection: Optional[AsyncBaseHTTPConnection] = None | ||
self.is_http11 = False | ||
self.is_http2 = False | ||
self.connect_failed = False | ||
self.expires_at: Optional[float] = None | ||
self.backend = AutoBackend() if backend is None else backend | ||
|
||
self.connection: AsyncBaseHTTPConnection | ||
http_version = self.socket.get_http_version() | ||
logger.trace( | ||
"create_connection socket=%r http_version=%r", self.socket, http_version | ||
) | ||
if http_version == "HTTP/2": | ||
from .http2 import AsyncHTTP2Connection | ||
|
||
self.is_http2 = True | ||
self.connection = AsyncHTTP2Connection( | ||
socket=self.socket, backend=self.backend, ssl_context=self.ssl_context | ||
) | ||
else: | ||
self.is_http11 = True | ||
self.connection = AsyncHTTP11Connection( | ||
socket=self.socket, ssl_context=self.ssl_context | ||
) | ||
|
||
def __repr__(self) -> str: | ||
http_version = "UNKNOWN" | ||
if self.is_http11: | ||
|
@@ -58,20 +55,10 @@ def __repr__(self) -> str: | |
return f"<AsyncHTTPConnection http_version={http_version} state={self.state}>" | ||
|
||
def info(self) -> str: | ||
if self.connection is None: | ||
return "Not connected" | ||
elif self.state == ConnectionState.PENDING: | ||
if self.state == ConnectionState.PENDING: | ||
return "Connecting" | ||
return self.connection.info() | ||
|
||
@property | ||
def request_lock(self) -> AsyncLock: | ||
# We do this lazily, to make sure backend autodetection always | ||
# runs within an async context. | ||
if not hasattr(self, "_request_lock"): | ||
self._request_lock = self.backend.create_lock() | ||
return self._request_lock | ||
|
||
async def arequest( | ||
self, | ||
method: bytes, | ||
|
@@ -81,103 +68,25 @@ async def arequest( | |
ext: dict = None, | ||
) -> Tuple[int, Headers, AsyncByteStream, dict]: | ||
assert url_to_origin(url) == self.origin | ||
ext = {} if ext is None else ext | ||
timeout = cast(TimeoutDict, ext.get("timeout", {})) | ||
|
||
async with self.request_lock: | ||
if self.state == ConnectionState.PENDING: | ||
if not self.socket: | ||
logger.trace( | ||
"open_socket origin=%r timeout=%r", self.origin, timeout | ||
) | ||
self.socket = await self._open_socket(timeout) | ||
self._create_connection(self.socket) | ||
elif self.state in (ConnectionState.READY, ConnectionState.IDLE): | ||
pass | ||
elif self.state == ConnectionState.ACTIVE and self.is_http2: | ||
pass | ||
else: | ||
raise NewConnectionRequired() | ||
Comment on lines
-87
to
-100
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure about why this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was also unsure what this was for, a quick glance at the repo and I couldn't find any location where it was caught to be handled. |
||
|
||
assert self.connection is not None | ||
logger.trace( | ||
"connection.arequest method=%r url=%r headers=%r", method, url, headers | ||
) | ||
return await self.connection.arequest(method, url, headers, stream, ext) | ||
|
||
async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream: | ||
scheme, hostname, port = self.origin | ||
timeout = {} if timeout is None else timeout | ||
ssl_context = self.ssl_context if scheme == b"https" else None | ||
|
||
retries_left = self.retries | ||
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) | ||
|
||
while True: | ||
try: | ||
if self.uds is None: | ||
return await self.backend.open_tcp_stream( | ||
hostname, | ||
port, | ||
ssl_context, | ||
timeout, | ||
local_address=self.local_address, | ||
) | ||
else: | ||
return await self.backend.open_uds_stream( | ||
self.uds, hostname, ssl_context, timeout | ||
) | ||
except (ConnectError, ConnectTimeout): | ||
if retries_left <= 0: | ||
self.connect_failed = True | ||
raise | ||
retries_left -= 1 | ||
delay = next(delays) | ||
await self.backend.sleep(delay) | ||
except Exception: # noqa: PIE786 | ||
self.connect_failed = True | ||
raise | ||
|
||
def _create_connection(self, socket: AsyncSocketStream) -> None: | ||
http_version = socket.get_http_version() | ||
logger.trace( | ||
"create_connection socket=%r http_version=%r", socket, http_version | ||
) | ||
if http_version == "HTTP/2": | ||
from .http2 import AsyncHTTP2Connection | ||
|
||
self.is_http2 = True | ||
self.connection = AsyncHTTP2Connection( | ||
socket=socket, backend=self.backend, ssl_context=self.ssl_context | ||
) | ||
else: | ||
self.is_http11 = True | ||
self.connection = AsyncHTTP11Connection( | ||
socket=socket, ssl_context=self.ssl_context | ||
) | ||
|
||
@property | ||
def state(self) -> ConnectionState: | ||
if self.connect_failed: | ||
return ConnectionState.CLOSED | ||
elif self.connection is None: | ||
return ConnectionState.PENDING | ||
return self.connection.get_state() | ||
|
||
def is_socket_readable(self) -> bool: | ||
return self.connection is not None and self.connection.is_socket_readable() | ||
return self.connection.is_socket_readable() | ||
|
||
def mark_as_ready(self) -> None: | ||
if self.connection is not None: | ||
self.connection.mark_as_ready() | ||
self.connection.mark_as_ready() | ||
|
||
async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None: | ||
if self.connection is not None: | ||
logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout) | ||
self.socket = await self.connection.start_tls(hostname, timeout) | ||
logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout) | ||
logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout) | ||
self.socket = await self.connection.start_tls(hostname, timeout) | ||
logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout) | ||
|
||
async def aclose(self) -> None: | ||
async with self.request_lock: | ||
if self.connection is not None: | ||
await self.connection.aclose() | ||
await self.connection.aclose() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,12 +12,23 @@ | |
cast, | ||
) | ||
|
||
from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore | ||
from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream | ||
from .._backends.base import lookup_async_backend | ||
from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol | ||
from .._exceptions import ( | ||
ConnectError, | ||
ConnectTimeout, | ||
LocalProtocolError, | ||
PoolTimeout, | ||
UnsupportedProtocol, | ||
) | ||
from .._threadlock import ThreadLock | ||
from .._types import URL, Headers, Origin, TimeoutDict | ||
from .._utils import get_logger, origin_to_url_string, url_to_origin | ||
from .._utils import ( | ||
exponential_backoff, | ||
get_logger, | ||
origin_to_url_string, | ||
url_to_origin, | ||
) | ||
from .base import ( | ||
AsyncByteStream, | ||
AsyncHTTPTransport, | ||
|
@@ -28,6 +39,8 @@ | |
|
||
logger = get_logger(__name__) | ||
|
||
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. | ||
|
||
|
||
class NullSemaphore(AsyncSemaphore): | ||
def __init__(self) -> None: | ||
|
@@ -144,6 +157,8 @@ def __init__( | |
"package is not installed. Use 'pip install httpcore[http2]'." | ||
) | ||
|
||
self._ssl_context.set_alpn_protocols(["http/1.1", "h2"]) | ||
|
||
@property | ||
def _connection_semaphore(self) -> AsyncSemaphore: | ||
# We do this lazily, to make sure backend autodetection always | ||
|
@@ -167,17 +182,50 @@ def _connection_acquiry_lock(self) -> AsyncLock: | |
def _create_connection( | ||
self, | ||
origin: Tuple[bytes, bytes, int], | ||
socket: AsyncSocketStream, | ||
) -> AsyncHTTPConnection: | ||
return AsyncHTTPConnection( | ||
origin=origin, | ||
http2=self._http2, | ||
uds=self._uds, | ||
socket=socket, | ||
ssl_context=self._ssl_context, | ||
local_address=self._local_address, | ||
retries=self._retries, | ||
backend=self._backend, | ||
) | ||
|
||
async def _open_socket( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Custom pool implementations would be able to override this method (along with a strict package pin) if they wanted to customize how sockets are opened. Standalone implementations (single connection w/o pooling) could copy this code and use them in their own There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While you can always have your own implementation you also need to implement the underlying connection code which is not impossible but it would be nice to utilise the existing connection codes in httpcore. By reimplementing all the code you
I totally understand this project is in it's infancy so you may not be prepared to expose the existing classes, I just wanted to point out something I wish to see in the future :) |
||
self, origin: Origin, timeout: TimeoutDict = None | ||
) -> AsyncSocketStream: | ||
scheme, hostname, port = origin | ||
timeout = {} if timeout is None else timeout | ||
ssl_context = self._ssl_context if scheme == b"https" else None | ||
|
||
logger.trace("open_socket origin=%r timeout=%r", origin, timeout) | ||
|
||
retries_left = self._retries | ||
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) | ||
|
||
while True: | ||
try: | ||
if self._uds is None: | ||
return await self._backend.open_tcp_stream( | ||
hostname, | ||
port, | ||
ssl_context, | ||
timeout, | ||
local_address=self._local_address, | ||
) | ||
else: | ||
return await self._backend.open_uds_stream( | ||
self._uds, hostname, ssl_context, timeout | ||
) | ||
except (ConnectError, ConnectTimeout): | ||
if retries_left <= 0: | ||
raise | ||
retries_left -= 1 | ||
delay = next(delays) | ||
await self._backend.sleep(delay) | ||
except Exception: # noqa: PIE786 | ||
raise | ||
|
||
async def arequest( | ||
self, | ||
method: bytes, | ||
|
@@ -208,7 +256,8 @@ async def arequest( | |
connection = await self._get_connection_from_pool(origin) | ||
|
||
if connection is None: | ||
connection = self._create_connection(origin=origin) | ||
socket = await self._open_socket(origin, timeout=timeout) | ||
connection = self._create_connection(origin=origin, socket=socket) | ||
logger.trace("created connection=%r", connection) | ||
await self._add_to_pool(connection, timeout=timeout) | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the change from which all others derive:
socket
becomes required, and will be passed by whoever wants to interact with an HTTP connection (in our case, the connection pool).