Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open sockets outside of connection objects #275

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 31 additions & 122 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,51 @@
from ssl import SSLContext
from typing import Optional, Tuple, cast
from typing import Optional, Tuple

from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend
from .._exceptions import ConnectError, ConnectTimeout
from .._backends.auto import AsyncBackend, AsyncSocketStream, AutoBackend
from .._types import URL, Headers, Origin, TimeoutDict
from .._utils import exponential_backoff, get_logger, url_to_origin
from .base import (
AsyncByteStream,
AsyncHTTPTransport,
ConnectionState,
NewConnectionRequired,
)
from .._utils import get_logger, url_to_origin
from .base import AsyncByteStream, AsyncHTTPTransport, ConnectionState
from .http import AsyncBaseHTTPConnection
from .http11 import AsyncHTTP11Connection

logger = get_logger(__name__)

RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.


class AsyncHTTPConnection(AsyncHTTPTransport):
def __init__(
self,
origin: Origin,
http2: bool = False,
uds: str = None,
socket: AsyncSocketStream,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the change from which all others derive: socket becomes required, and will be passed by whoever wants to interact with an HTTP connection (in our case, the connection pool).

ssl_context: SSLContext = None,
socket: AsyncSocketStream = None,
local_address: str = None,
retries: int = 0,
backend: AsyncBackend = None,
):
self.origin = origin
self.http2 = http2
self.uds = uds
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket
self.local_address = local_address
self.retries = retries

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
self.ssl_context = SSLContext() if ssl_context is None else ssl_context

self.connection: Optional[AsyncBaseHTTPConnection] = None
self.is_http11 = False
self.is_http2 = False
self.connect_failed = False
self.expires_at: Optional[float] = None
self.backend = AutoBackend() if backend is None else backend

self.connection: AsyncBaseHTTPConnection
http_version = self.socket.get_http_version()
logger.trace(
"create_connection socket=%r http_version=%r", self.socket, http_version
)
if http_version == "HTTP/2":
from .http2 import AsyncHTTP2Connection

self.is_http2 = True
self.connection = AsyncHTTP2Connection(
socket=self.socket, backend=self.backend, ssl_context=self.ssl_context
)
else:
self.is_http11 = True
self.connection = AsyncHTTP11Connection(
socket=self.socket, ssl_context=self.ssl_context
)

def __repr__(self) -> str:
http_version = "UNKNOWN"
if self.is_http11:
Expand All @@ -58,20 +55,10 @@ def __repr__(self) -> str:
return f"<AsyncHTTPConnection http_version={http_version} state={self.state}>"

def info(self) -> str:
if self.connection is None:
return "Not connected"
elif self.state == ConnectionState.PENDING:
if self.state == ConnectionState.PENDING:
return "Connecting"
return self.connection.info()

@property
def request_lock(self) -> AsyncLock:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_request_lock"):
self._request_lock = self.backend.create_lock()
return self._request_lock

async def arequest(
self,
method: bytes,
Expand All @@ -81,103 +68,25 @@ async def arequest(
ext: dict = None,
) -> Tuple[int, Headers, AsyncByteStream, dict]:
assert url_to_origin(url) == self.origin
ext = {} if ext is None else ext
timeout = cast(TimeoutDict, ext.get("timeout", {}))

async with self.request_lock:
if self.state == ConnectionState.PENDING:
if not self.socket:
logger.trace(
"open_socket origin=%r timeout=%r", self.origin, timeout
)
self.socket = await self._open_socket(timeout)
self._create_connection(self.socket)
elif self.state in (ConnectionState.READY, ConnectionState.IDLE):
pass
elif self.state == ConnectionState.ACTIVE and self.is_http2:
pass
else:
raise NewConnectionRequired()
Comment on lines -87 to -100
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this NewConnectionRequired() case was introduced, so it's possible the new implementation is missing something in the HTTP/1.1 case.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also unsure what this was for. After a quick glance at the repo, I couldn't find any location where it is caught and handled.


assert self.connection is not None
logger.trace(
"connection.arequest method=%r url=%r headers=%r", method, url, headers
)
return await self.connection.arequest(method, url, headers, stream, ext)

async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream:
scheme, hostname, port = self.origin
timeout = {} if timeout is None else timeout
ssl_context = self.ssl_context if scheme == b"https" else None

retries_left = self.retries
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

while True:
try:
if self.uds is None:
return await self.backend.open_tcp_stream(
hostname,
port,
ssl_context,
timeout,
local_address=self.local_address,
)
else:
return await self.backend.open_uds_stream(
self.uds, hostname, ssl_context, timeout
)
except (ConnectError, ConnectTimeout):
if retries_left <= 0:
self.connect_failed = True
raise
retries_left -= 1
delay = next(delays)
await self.backend.sleep(delay)
except Exception: # noqa: PIE786
self.connect_failed = True
raise

def _create_connection(self, socket: AsyncSocketStream) -> None:
http_version = socket.get_http_version()
logger.trace(
"create_connection socket=%r http_version=%r", socket, http_version
)
if http_version == "HTTP/2":
from .http2 import AsyncHTTP2Connection

self.is_http2 = True
self.connection = AsyncHTTP2Connection(
socket=socket, backend=self.backend, ssl_context=self.ssl_context
)
else:
self.is_http11 = True
self.connection = AsyncHTTP11Connection(
socket=socket, ssl_context=self.ssl_context
)

@property
def state(self) -> ConnectionState:
if self.connect_failed:
return ConnectionState.CLOSED
elif self.connection is None:
return ConnectionState.PENDING
return self.connection.get_state()

def is_socket_readable(self) -> bool:
return self.connection is not None and self.connection.is_socket_readable()
return self.connection.is_socket_readable()

def mark_as_ready(self) -> None:
if self.connection is not None:
self.connection.mark_as_ready()
self.connection.mark_as_ready()

async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
if self.connection is not None:
logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout)
self.socket = await self.connection.start_tls(hostname, timeout)
logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout)
logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout)
self.socket = await self.connection.start_tls(hostname, timeout)
logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout)

async def aclose(self) -> None:
async with self.request_lock:
if self.connection is not None:
await self.connection.aclose()
await self.connection.aclose()
65 changes: 57 additions & 8 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,23 @@
cast,
)

from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore
from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream
from .._backends.base import lookup_async_backend
from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol
from .._exceptions import (
ConnectError,
ConnectTimeout,
LocalProtocolError,
PoolTimeout,
UnsupportedProtocol,
)
from .._threadlock import ThreadLock
from .._types import URL, Headers, Origin, TimeoutDict
from .._utils import get_logger, origin_to_url_string, url_to_origin
from .._utils import (
exponential_backoff,
get_logger,
origin_to_url_string,
url_to_origin,
)
from .base import (
AsyncByteStream,
AsyncHTTPTransport,
Expand All @@ -28,6 +39,8 @@

logger = get_logger(__name__)

RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.


class NullSemaphore(AsyncSemaphore):
def __init__(self) -> None:
Expand Down Expand Up @@ -144,6 +157,8 @@ def __init__(
"package is not installed. Use 'pip install httpcore[http2]'."
)

self._ssl_context.set_alpn_protocols(["http/1.1", "h2"])

@property
def _connection_semaphore(self) -> AsyncSemaphore:
# We do this lazily, to make sure backend autodetection always
Expand All @@ -167,17 +182,50 @@ def _connection_acquiry_lock(self) -> AsyncLock:
def _create_connection(
self,
origin: Tuple[bytes, bytes, int],
socket: AsyncSocketStream,
) -> AsyncHTTPConnection:
return AsyncHTTPConnection(
origin=origin,
http2=self._http2,
uds=self._uds,
socket=socket,
ssl_context=self._ssl_context,
local_address=self._local_address,
retries=self._retries,
backend=self._backend,
)

async def _open_socket(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Custom pool implementations would be able to override this method (along with a strict package pin) if they wanted to customize how sockets are opened.

Standalone implementations (single connection w/o pooling) could copy this code and use it in their own open_custom_connection() context manager…

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you can always have your own implementation, you then also need to implement the underlying connection code. That is not impossible, but it would be nice to utilise the existing connection code in httpcore. By reimplementing all the code you:

  • Duplicate work: it's IMO a non-trivial amount of code to get a basic HTTP connection going
  • Miss out on new functionality/bugfixes added by the httpcore maintainers, who know what they are doing when it comes to issues
  • Make your own code harder to maintain, because you now need to worry about h11, h2 and how to interact with them

I totally understand this project is in its infancy, so you may not be prepared to expose the existing classes; I just wanted to point out something I wish to see in the future :)

self, origin: Origin, timeout: TimeoutDict = None
) -> AsyncSocketStream:
scheme, hostname, port = origin
timeout = {} if timeout is None else timeout
ssl_context = self._ssl_context if scheme == b"https" else None

logger.trace("open_socket origin=%r timeout=%r", origin, timeout)

retries_left = self._retries
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

while True:
try:
if self._uds is None:
return await self._backend.open_tcp_stream(
hostname,
port,
ssl_context,
timeout,
local_address=self._local_address,
)
else:
return await self._backend.open_uds_stream(
self._uds, hostname, ssl_context, timeout
)
except (ConnectError, ConnectTimeout):
if retries_left <= 0:
raise
retries_left -= 1
delay = next(delays)
await self._backend.sleep(delay)
except Exception: # noqa: PIE786
raise

async def arequest(
self,
method: bytes,
Expand Down Expand Up @@ -208,7 +256,8 @@ async def arequest(
connection = await self._get_connection_from_pool(origin)

if connection is None:
connection = self._create_connection(origin=origin)
socket = await self._open_socket(origin, timeout=timeout)
connection = self._create_connection(origin=origin, socket=socket)
logger.trace("created connection=%r", connection)
await self._add_to_pool(connection, timeout=timeout)
else:
Expand Down
11 changes: 7 additions & 4 deletions httpcore/_async/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,11 @@ async def _forward_request(
connection = await self._get_connection_from_pool(origin)

if connection is None:
socket = await self._open_socket(origin, timeout=timeout)
connection = AsyncHTTPConnection(
origin=origin, http2=self._http2, ssl_context=self._ssl_context
origin=origin,
socket=socket,
ssl_context=self._ssl_context,
)
await self._add_to_pool(connection, timeout)

Expand Down Expand Up @@ -193,9 +196,10 @@ async def _tunnel_request(
scheme, host, port = origin

# First, create a connection to the proxy server
socket = await self._open_socket(origin=self.proxy_origin, timeout=timeout)
proxy_connection = AsyncHTTPConnection(
origin=self.proxy_origin,
http2=self._http2,
socket=socket,
ssl_context=self._ssl_context,
)

Expand Down Expand Up @@ -247,9 +251,8 @@ async def _tunnel_request(
# retain the tunnel.
connection = AsyncHTTPConnection(
origin=origin,
http2=self._http2,
ssl_context=self._ssl_context,
socket=proxy_connection.socket,
ssl_context=self._ssl_context,
)
await self._add_to_pool(connection, timeout)

Expand Down
Loading