diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index 530663d8..9e03e6ba 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -1,54 +1,51 @@ from ssl import SSLContext -from typing import Optional, Tuple, cast +from typing import Optional, Tuple -from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend -from .._exceptions import ConnectError, ConnectTimeout +from .._backends.auto import AsyncBackend, AsyncSocketStream, AutoBackend from .._types import URL, Headers, Origin, TimeoutDict -from .._utils import exponential_backoff, get_logger, url_to_origin -from .base import ( - AsyncByteStream, - AsyncHTTPTransport, - ConnectionState, - NewConnectionRequired, -) +from .._utils import get_logger, url_to_origin +from .base import AsyncByteStream, AsyncHTTPTransport, ConnectionState from .http import AsyncBaseHTTPConnection from .http11 import AsyncHTTP11Connection logger = get_logger(__name__) -RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. - class AsyncHTTPConnection(AsyncHTTPTransport): def __init__( self, origin: Origin, - http2: bool = False, - uds: str = None, + socket: AsyncSocketStream, ssl_context: SSLContext = None, - socket: AsyncSocketStream = None, - local_address: str = None, - retries: int = 0, backend: AsyncBackend = None, ): self.origin = origin - self.http2 = http2 - self.uds = uds - self.ssl_context = SSLContext() if ssl_context is None else ssl_context self.socket = socket - self.local_address = local_address - self.retries = retries - - if self.http2: - self.ssl_context.set_alpn_protocols(["http/1.1", "h2"]) + self.ssl_context = SSLContext() if ssl_context is None else ssl_context - self.connection: Optional[AsyncBaseHTTPConnection] = None self.is_http11 = False self.is_http2 = False - self.connect_failed = False self.expires_at: Optional[float] = None self.backend = AutoBackend() if backend is None else backend + self.connection: AsyncBaseHTTPConnection + http_version = self.socket.get_http_version() + logger.trace( + "create_connection socket=%r http_version=%r", self.socket, http_version + ) + if http_version == "HTTP/2": + from .http2 import AsyncHTTP2Connection + + self.is_http2 = True + self.connection = AsyncHTTP2Connection( + socket=self.socket, backend=self.backend, ssl_context=self.ssl_context + ) + else: + self.is_http11 = True + self.connection = AsyncHTTP11Connection( + socket=self.socket, ssl_context=self.ssl_context + ) + def __repr__(self) -> str: http_version = "UNKNOWN" if self.is_http11: @@ -58,20 +55,10 @@ def __repr__(self) -> str: return f"" def info(self) -> str: - if self.connection is None: - return "Not connected" - elif self.state == ConnectionState.PENDING: + if self.state == ConnectionState.PENDING: return "Connecting" return self.connection.info() - @property - def request_lock(self) -> AsyncLock: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. - if not hasattr(self, "_request_lock"): - self._request_lock = self.backend.create_lock() - return self._request_lock - async def arequest( self, method: bytes, @@ -81,103 +68,25 @@ async def arequest( ext: dict = None, ) -> Tuple[int, Headers, AsyncByteStream, dict]: assert url_to_origin(url) == self.origin - ext = {} if ext is None else ext - timeout = cast(TimeoutDict, ext.get("timeout", {})) - - async with self.request_lock: - if self.state == ConnectionState.PENDING: - if not self.socket: - logger.trace( - "open_socket origin=%r timeout=%r", self.origin, timeout - ) - self.socket = await self._open_socket(timeout) - self._create_connection(self.socket) - elif self.state in (ConnectionState.READY, ConnectionState.IDLE): - pass - elif self.state == ConnectionState.ACTIVE and self.is_http2: - pass - else: - raise NewConnectionRequired() - - assert self.connection is not None logger.trace( "connection.arequest method=%r url=%r headers=%r", method, url, headers ) return await self.connection.arequest(method, url, headers, stream, ext) - async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream: - scheme, hostname, port = self.origin - timeout = {} if timeout is None else timeout - ssl_context = self.ssl_context if scheme == b"https" else None - - retries_left = self.retries - delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) - - while True: - try: - if self.uds is None: - return await self.backend.open_tcp_stream( - hostname, - port, - ssl_context, - timeout, - local_address=self.local_address, - ) - else: - return await self.backend.open_uds_stream( - self.uds, hostname, ssl_context, timeout - ) - except (ConnectError, ConnectTimeout): - if retries_left <= 0: - self.connect_failed = True - raise - retries_left -= 1 - delay = next(delays) - await self.backend.sleep(delay) - except Exception: # noqa: PIE786 - self.connect_failed = True - raise - - def _create_connection(self, socket: AsyncSocketStream) -> None: - http_version = socket.get_http_version() - logger.trace( - "create_connection socket=%r http_version=%r", socket, http_version - ) - if http_version == "HTTP/2": - from .http2 import AsyncHTTP2Connection - - self.is_http2 = True - self.connection = AsyncHTTP2Connection( - socket=socket, backend=self.backend, ssl_context=self.ssl_context - ) - else: - self.is_http11 = True - self.connection = AsyncHTTP11Connection( - socket=socket, ssl_context=self.ssl_context - ) - @property def state(self) -> ConnectionState: - if self.connect_failed: - return ConnectionState.CLOSED - elif self.connection is None: - return ConnectionState.PENDING return self.connection.get_state() def is_socket_readable(self) -> bool: - return self.connection is not None and self.connection.is_socket_readable() + return self.connection.is_socket_readable() def mark_as_ready(self) -> None: - if self.connection is not None: - self.connection.mark_as_ready() + self.connection.mark_as_ready() async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None: - if self.connection is not None: - logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout) - self.socket = await self.connection.start_tls(hostname, timeout) - logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout) + logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout) + self.socket = await self.connection.start_tls(hostname, timeout) + logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout) async def aclose(self) -> None: - async with self.request_lock: - if self.connection is not None: - await self.connection.aclose() + await self.connection.aclose() diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 46ede0ba..c310e38d 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -12,12 +12,23 @@ cast, ) -from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore +from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream from .._backends.base import lookup_async_backend -from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol +from .._exceptions import ( + ConnectError, + ConnectTimeout, + LocalProtocolError, + PoolTimeout, + UnsupportedProtocol, +) from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict -from .._utils import get_logger, origin_to_url_string, url_to_origin +from .._utils import ( + exponential_backoff, + get_logger, + origin_to_url_string, + url_to_origin, +) from .base import ( AsyncByteStream, AsyncHTTPTransport, @@ -28,6 +39,8 @@ logger = get_logger(__name__) +RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. + class NullSemaphore(AsyncSemaphore): def __init__(self) -> None: @@ -144,6 +157,8 @@ def __init__( "package is not installed. Use 'pip install httpcore[http2]'." ) + self._ssl_context.set_alpn_protocols(["http/1.1", "h2"]) + @property def _connection_semaphore(self) -> AsyncSemaphore: # We do this lazily, to make sure backend autodetection always @@ -167,17 +182,50 @@ def _connection_acquiry_lock(self) -> AsyncLock: def _create_connection( self, origin: Tuple[bytes, bytes, int], + socket: AsyncSocketStream, ) -> AsyncHTTPConnection: return AsyncHTTPConnection( origin=origin, - http2=self._http2, - uds=self._uds, + socket=socket, ssl_context=self._ssl_context, - local_address=self._local_address, - retries=self._retries, backend=self._backend, ) + async def _open_socket( + self, origin: Origin, timeout: TimeoutDict = None + ) -> AsyncSocketStream: + scheme, hostname, port = origin + timeout = {} if timeout is None else timeout + ssl_context = self._ssl_context if scheme == b"https" else None + + logger.trace("open_socket origin=%r timeout=%r", origin, timeout) + + retries_left = self._retries + delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) + + while True: + try: + if self._uds is None: + return await self._backend.open_tcp_stream( + hostname, + port, + ssl_context, + timeout, + local_address=self._local_address, + ) + else: + return await self._backend.open_uds_stream( + self._uds, hostname, ssl_context, timeout + ) + except (ConnectError, ConnectTimeout): + if retries_left <= 0: + raise + retries_left -= 1 + delay = next(delays) + await self._backend.sleep(delay) + except Exception: # noqa: PIE786 + raise + async def arequest( self, method: bytes, @@ -208,7 +256,8 @@ async def arequest( connection = await self._get_connection_from_pool(origin) if connection is None: - connection = self._create_connection(origin=origin) + socket = await self._open_socket(origin, timeout=timeout) + connection = self._create_connection(origin=origin, socket=socket) logger.trace("created connection=%r", connection) await self._add_to_pool(connection, timeout=timeout) else: diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index d9df762b..716f95d2 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -143,8 +143,11 @@ async def _forward_request( connection = await self._get_connection_from_pool(origin) if connection is None: + socket = await self._open_socket(origin, timeout=timeout) connection = AsyncHTTPConnection( - origin=origin, http2=self._http2, ssl_context=self._ssl_context + origin=origin, + socket=socket, + ssl_context=self._ssl_context, ) await self._add_to_pool(connection, timeout) @@ -193,9 +196,10 @@ async def _tunnel_request( scheme, host, port = origin # First, create a connection to the proxy server + socket = await self._open_socket(origin=self.proxy_origin, timeout=timeout) proxy_connection = AsyncHTTPConnection( origin=self.proxy_origin, - http2=self._http2, + socket=socket, ssl_context=self._ssl_context, ) @@ -247,9 +251,8 @@ async def _tunnel_request( # retain the tunnel. connection = AsyncHTTPConnection( origin=origin, - http2=self._http2, - ssl_context=self._ssl_context, socket=proxy_connection.socket, + ssl_context=self._ssl_context, ) await self._add_to_pool(connection, timeout) diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 04042227..3949c8fd 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -1,54 +1,51 @@ from ssl import SSLContext -from typing import Optional, Tuple, cast +from typing import Optional, Tuple -from .._backends.sync import SyncBackend, SyncLock, SyncSocketStream, SyncBackend -from .._exceptions import ConnectError, ConnectTimeout +from .._backends.sync import SyncBackend, SyncSocketStream, SyncBackend from .._types import URL, Headers, Origin, TimeoutDict -from .._utils import exponential_backoff, get_logger, url_to_origin -from .base import ( - SyncByteStream, - SyncHTTPTransport, - ConnectionState, - NewConnectionRequired, -) +from .._utils import get_logger, url_to_origin +from .base import SyncByteStream, SyncHTTPTransport, ConnectionState from .http import SyncBaseHTTPConnection from .http11 import SyncHTTP11Connection logger = get_logger(__name__) -RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. - class SyncHTTPConnection(SyncHTTPTransport): def __init__( self, origin: Origin, - http2: bool = False, - uds: str = None, + socket: SyncSocketStream, ssl_context: SSLContext = None, - socket: SyncSocketStream = None, - local_address: str = None, - retries: int = 0, backend: SyncBackend = None, ): self.origin = origin - self.http2 = http2 - self.uds = uds - self.ssl_context = SSLContext() if ssl_context is None else ssl_context self.socket = socket - self.local_address = local_address - self.retries = retries - - if self.http2: - self.ssl_context.set_alpn_protocols(["http/1.1", "h2"]) + self.ssl_context = SSLContext() if ssl_context is None else ssl_context - self.connection: Optional[SyncBaseHTTPConnection] = None self.is_http11 = False self.is_http2 = False - self.connect_failed = False self.expires_at: Optional[float] = None self.backend = SyncBackend() if backend is None else backend + self.connection: SyncBaseHTTPConnection + http_version = self.socket.get_http_version() + logger.trace( + "create_connection socket=%r http_version=%r", self.socket, http_version + ) + if http_version == "HTTP/2": + from .http2 import SyncHTTP2Connection + + self.is_http2 = True + self.connection = SyncHTTP2Connection( + socket=self.socket, backend=self.backend, ssl_context=self.ssl_context + ) + else: + self.is_http11 = True + self.connection = SyncHTTP11Connection( + socket=self.socket, ssl_context=self.ssl_context + ) + def __repr__(self) -> str: http_version = "UNKNOWN" if self.is_http11: @@ -58,20 +55,10 @@ def __repr__(self) -> str: return f"" def info(self) -> str: - if self.connection is None: - return "Not connected" - elif self.state == ConnectionState.PENDING: + if self.state == ConnectionState.PENDING: return "Connecting" return self.connection.info() - @property - def request_lock(self) -> SyncLock: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. - if not hasattr(self, "_request_lock"): - self._request_lock = self.backend.create_lock() - return self._request_lock - def request( self, method: bytes, @@ -81,103 +68,25 @@ def request( ext: dict = None, ) -> Tuple[int, Headers, SyncByteStream, dict]: assert url_to_origin(url) == self.origin - ext = {} if ext is None else ext - timeout = cast(TimeoutDict, ext.get("timeout", {})) - - with self.request_lock: - if self.state == ConnectionState.PENDING: - if not self.socket: - logger.trace( - "open_socket origin=%r timeout=%r", self.origin, timeout - ) - self.socket = self._open_socket(timeout) - self._create_connection(self.socket) - elif self.state in (ConnectionState.READY, ConnectionState.IDLE): - pass - elif self.state == ConnectionState.ACTIVE and self.is_http2: - pass - else: - raise NewConnectionRequired() - - assert self.connection is not None logger.trace( "connection.request method=%r url=%r headers=%r", method, url, headers ) return self.connection.request(method, url, headers, stream, ext) - def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream: - scheme, hostname, port = self.origin - timeout = {} if timeout is None else timeout - ssl_context = self.ssl_context if scheme == b"https" else None - - retries_left = self.retries - delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) - - while True: - try: - if self.uds is None: - return self.backend.open_tcp_stream( - hostname, - port, - ssl_context, - timeout, - local_address=self.local_address, - ) - else: - return self.backend.open_uds_stream( - self.uds, hostname, ssl_context, timeout - ) - except (ConnectError, ConnectTimeout): - if retries_left <= 0: - self.connect_failed = True - raise - retries_left -= 1 - delay = next(delays) - self.backend.sleep(delay) - except Exception: # noqa: PIE786 - self.connect_failed = True - raise - - def _create_connection(self, socket: SyncSocketStream) -> None: - http_version = socket.get_http_version() - logger.trace( - "create_connection socket=%r http_version=%r", socket, http_version - ) - if http_version == "HTTP/2": - from .http2 import SyncHTTP2Connection - - self.is_http2 = True - self.connection = SyncHTTP2Connection( - socket=socket, backend=self.backend, ssl_context=self.ssl_context - ) - else: - self.is_http11 = True - self.connection = SyncHTTP11Connection( - socket=socket, ssl_context=self.ssl_context - ) - @property def state(self) -> ConnectionState: - if self.connect_failed: - return ConnectionState.CLOSED - elif self.connection is None: - return ConnectionState.PENDING return self.connection.get_state() def is_socket_readable(self) -> bool: - return self.connection is not None and self.connection.is_socket_readable() + return self.connection.is_socket_readable() def mark_as_ready(self) -> None: - if self.connection is not None: - self.connection.mark_as_ready() + self.connection.mark_as_ready() def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None: - if self.connection is not None: - logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout) - self.socket = self.connection.start_tls(hostname, timeout) - logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout) + logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout) + self.socket = self.connection.start_tls(hostname, timeout) + logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout) def close(self) -> None: - with self.request_lock: - if self.connection is not None: - self.connection.close() + self.connection.close() diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 4702184b..e1a56a4d 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -12,12 +12,23 @@ cast, ) -from .._backends.sync import SyncBackend, SyncLock, SyncSemaphore +from .._backends.sync import SyncBackend, SyncLock, SyncSemaphore, SyncSocketStream from .._backends.base import lookup_sync_backend -from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol +from .._exceptions import ( + ConnectError, + ConnectTimeout, + LocalProtocolError, + PoolTimeout, + UnsupportedProtocol, +) from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict -from .._utils import get_logger, origin_to_url_string, url_to_origin +from .._utils import ( + exponential_backoff, + get_logger, + origin_to_url_string, + url_to_origin, +) from .base import ( SyncByteStream, SyncHTTPTransport, @@ -28,6 +39,8 @@ logger = get_logger(__name__) +RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. + class NullSemaphore(SyncSemaphore): def __init__(self) -> None: @@ -144,6 +157,8 @@ def __init__( "package is not installed. Use 'pip install httpcore[http2]'." ) + self._ssl_context.set_alpn_protocols(["http/1.1", "h2"]) + @property def _connection_semaphore(self) -> SyncSemaphore: # We do this lazily, to make sure backend autodetection always @@ -167,17 +182,50 @@ def _connection_acquiry_lock(self) -> SyncLock: def _create_connection( self, origin: Tuple[bytes, bytes, int], + socket: SyncSocketStream, ) -> SyncHTTPConnection: return SyncHTTPConnection( origin=origin, - http2=self._http2, - uds=self._uds, + socket=socket, ssl_context=self._ssl_context, - local_address=self._local_address, - retries=self._retries, backend=self._backend, ) + def _open_socket( + self, origin: Origin, timeout: TimeoutDict = None + ) -> SyncSocketStream: + scheme, hostname, port = origin + timeout = {} if timeout is None else timeout + ssl_context = self._ssl_context if scheme == b"https" else None + + logger.trace("open_socket origin=%r timeout=%r", origin, timeout) + + retries_left = self._retries + delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) + + while True: + try: + if self._uds is None: + return self._backend.open_tcp_stream( + hostname, + port, + ssl_context, + timeout, + local_address=self._local_address, + ) + else: + return self._backend.open_uds_stream( + self._uds, hostname, ssl_context, timeout + ) + except (ConnectError, ConnectTimeout): + if retries_left <= 0: + raise + retries_left -= 1 + delay = next(delays) + self._backend.sleep(delay) + except Exception: # noqa: PIE786 + raise + def request( self, method: bytes, @@ -208,7 +256,8 @@ def request( connection = self._get_connection_from_pool(origin) if connection is None: - connection = self._create_connection(origin=origin) + socket = self._open_socket(origin, timeout=timeout) + connection = self._create_connection(origin=origin, socket=socket) logger.trace("created connection=%r", connection) self._add_to_pool(connection, timeout=timeout) else: diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index f5576c01..631d8ec7 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -143,8 +143,11 @@ def _forward_request( connection = self._get_connection_from_pool(origin) if connection is None: + socket = self._open_socket(origin, timeout=timeout) connection = SyncHTTPConnection( - origin=origin, http2=self._http2, ssl_context=self._ssl_context + origin=origin, + socket=socket, + ssl_context=self._ssl_context, ) self._add_to_pool(connection, timeout) @@ -193,9 +196,10 @@ def _tunnel_request( scheme, host, port = origin # First, create a connection to the proxy server + socket = self._open_socket(origin=self.proxy_origin, timeout=timeout) proxy_connection = SyncHTTPConnection( origin=self.proxy_origin, - http2=self._http2, + socket=socket, ssl_context=self._ssl_context, ) @@ -247,9 +251,8 @@ def _tunnel_request( # retain the tunnel. connection = SyncHTTPConnection( origin=origin, - http2=self._http2, - ssl_context=self._ssl_context, socket=proxy_connection.socket, + ssl_context=self._ssl_context, ) self._add_to_pool(connection, timeout)