From 2c05a12568a543290626890b346b918e2f0f1ad5 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sat, 15 Jul 2023 07:01:17 +0300 Subject: [PATCH] Refactor to support parameterized OpenSSL.SSLStream --- src/Connections.jl | 59 ++++++++++------------------------------------ src/HTTP.jl | 6 ++--- src/Servers.jl | 2 +- test/client.jl | 6 ++--- 4 files changed, 19 insertions(+), 54 deletions(-) diff --git a/src/Connections.jl b/src/Connections.jl index f41cdd282..e2c2b2932 100644 --- a/src/Connections.jl +++ b/src/Connections.jl @@ -39,18 +39,11 @@ function __init__() # there was no artificial restriction on overall throughput default_connection_limit[] = max(16, Threads.nthreads() * 4) nosslcontext[] = OpenSSL.SSLContext(OpenSSL.TLSClientMethod()) - TCP_POOL[] = CPool{Sockets.TCPSocket}(default_connection_limit[]) - MBEDTLS_POOL[] = CPool{MbedTLS.SSLContext}(default_connection_limit[]) - OPENSSL_POOL[] = CPool{OpenSSL.SSLStream}(default_connection_limit[]) return end function set_default_connection_limit!(n) default_connection_limit[] = n - # reinitialize the global connection pools - TCP_POOL[] = CPool{Sockets.TCPSocket}(n) - MBEDTLS_POOL[] = CPool{MbedTLS.SSLContext}(n) - OPENSSL_POOL[] = CPool{OpenSSL.SSLStream}(n) return end @@ -360,47 +353,27 @@ A pool can be passed to any of the `HTTP.request` methods via the `pool` keyword """ struct Pool lock::ReentrantLock - tcp::CPool{Sockets.TCPSocket} - mbedtls::CPool{MbedTLS.SSLContext} - openssl::CPool{OpenSSL.SSLStream} - other::IdDict{Type, CPool} + pools::IdDict{Type, CPool} max::Int end function Pool(max::Union{Int, Nothing}=nothing) max = something(max, default_connection_limit[]) return Pool(ReentrantLock(), - CPool{Sockets.TCPSocket}(max), - CPool{MbedTLS.SSLContext}(max), - CPool{OpenSSL.SSLStream}(max), IdDict{Type, CPool}(), max, ) end -# Default HTTP global connection pools -const TCP_POOL = Ref{CPool{Sockets.TCPSocket}}() -const MBEDTLS_POOL = Ref{CPool{MbedTLS.SSLContext}}() -const OPENSSL_POOL = Ref{CPool{OpenSSL.SSLStream}}() -const OTHER_POOL = Lockable(IdDict{Type, CPool}()) +# Default HTTP global connection pool +const POOL = Lockable(IdDict{Type, CPool}()) -getpool(::Nothing, ::Type{Sockets.TCPSocket}) = TCP_POOL[] -getpool(::Nothing, ::Type{MbedTLS.SSLContext}) = MBEDTLS_POOL[] -getpool(::Nothing, ::Type{OpenSSL.SSLStream}) = OPENSSL_POOL[] -getpool(::Nothing, ::Type{T}) where {T} = Base.@lock OTHER_POOL get!(OTHER_POOL[], T) do +getpool(::Nothing, ::Type{T}) where {T} = Base.@lock POOL get!(POOL[], T) do CPool{T}(default_connection_limit[]) end function getpool(pool::Pool, ::Type{T})::CPool{T} where {T} - if T === Sockets.TCPSocket - return pool.tcp - elseif T === MbedTLS.SSLContext - return pool.mbedtls - elseif T === OpenSSL.SSLStream - return pool.openssl - else - return Base.@lock pool.lock get!(() -> CPool{T}(pool.max), pool.other, T) - end + return Base.@lock pool.lock get!(() -> CPool{T}(pool.max), pool.pools, T) end """ @@ -411,15 +384,9 @@ If `pool` is not specified, the default global pools are closed. """ function closeall(pool::Union{Nothing, Pool}=nothing) if pool === nothing - drain!(TCP_POOL[]) - drain!(MBEDTLS_POOL[]) - drain!(OPENSSL_POOL[]) - Base.@lock OTHER_POOL foreach(drain!, values(OTHER_POOL[])) + Base.@lock POOL foreach(drain!, values(POOL[])) else - drain!(pool.tcp) - drain!(pool.mbedtls) - drain!(pool.openssl) - Base.@lock pool.lock foreach(drain!, values(pool.other)) + Base.@lock pool.lock foreach(drain!, values(pool.pools)) end return end @@ -570,20 +537,20 @@ function getconnection(::Type{SSLContext}, return sslconnection(SSLContext, tcp, host; kw...) end -function getconnection(::Type{SSLStream}, +function getconnection(::Type{SSLStream{T}}, host::AbstractString, port::AbstractString; - kw...)::SSLStream + kw...)::SSLStream{T} where {T} port = isempty(port) ? "443" : port @debugv 2 "SSL connect: $host:$port..." - tcp = getconnection(TCPSocket, host, port; kw...) - return sslconnection(SSLStream, tcp, host; kw...) + tcp = getconnection(T, host, port; kw...) + return sslconnection(SSLStream{T}, tcp, host; kw...) end -function sslconnection(::Type{SSLStream}, tcp::TCPSocket, host::AbstractString; +function sslconnection(::Type{SSLStream{T}}, tcp::T, host::AbstractString; require_ssl_verification::Bool=NetworkOptions.verify_host(host, "SSL"), sslconfig::OpenSSL.SSLContext=nosslcontext[], - kw...)::SSLStream + kw...)::SSLStream{T} where {T} if sslconfig === nosslcontext[] sslconfig = global_sslcontext() end diff --git a/src/HTTP.jl b/src/HTTP.jl index cc3d027b8..2bef0d93b 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -24,7 +24,7 @@ end function open end -const SOCKET_TYPE_TLS = Ref{Any}(OpenSSL.SSLStream) +const SOCKET_TYPE_TLS = Ref{Any}(OpenSSL.SSLStream{TCPSocket}) include("Conditions.jl") ;using .Conditions include("access_log.jl") @@ -190,8 +190,8 @@ SSL arguments: ["... peer must present a valid certificate, handshake is aborted if verification failed."](https://tls.mbed.org/api/ssl_8h.html#a5695285c9dbfefec295012b566290f37) - `sslconfig = SSLConfig(require_ssl_verification)` - - `socket_type_tls = MbedTLS.SSLContext`, the type of socket to use for TLS connections. Defaults to `MbedTLS.SSLContext`. - Also supported is passing `socket_type_tls = OpenSSL.SSLStream`. To change the global default, set `HTTP.SOCKET_TYPE_TLS[] = OpenSSL.SSLStream`. + - `socket_type_tls = OpenSSL.SSLStream{TCPSocket}`, the type of socket to use for TLS connections. Defaults to `OpenSSL.SSLStream{TCPSocket}`. + Also supported is passing `socket_type_tls = MbedTLS.SSLContext`. To change the global default, set `HTTP.SOCKET_TYPE_TLS[] = MbedTLS.SSLContext`. Cookie arguments: - `cookies::Union{Bool, Dict{<:AbstractString, <:AbstractString}} = true`, enable cookies, or alternatively, diff --git a/src/Servers.jl b/src/Servers.jl index 6afb16590..261757957 100644 --- a/src/Servers.jl +++ b/src/Servers.jl @@ -14,7 +14,7 @@ export listen, listen!, Server, forceclose, port using Sockets, Logging, LoggingExtras, MbedTLS, Dates using MbedTLS: SSLContext, SSLConfig using ..IOExtras, ..Streams, ..Messages, ..Parsers, ..Connections, ..Exceptions -import ..access_threaded, ..SOCKET_TYPE_TLS, ..@logfmt_str +import ..access_threaded, ..@logfmt_str TRUE(x) = true getinet(host::String, port::Integer) = Sockets.InetAddr(parse(IPAddr, host), port) diff --git a/test/client.jl b/test/client.jl index 4c12f9c64..89c1f236f 100644 --- a/test/client.jl +++ b/test/client.jl @@ -14,9 +14,7 @@ using InteractiveUtils: @which # test we can adjust default_connection_limit for x in (10, 12) HTTP.set_default_connection_limit!(x) - @test HTTP.Connections.TCP_POOL[].max == x - @test HTTP.Connections.MBEDTLS_POOL[].max == x - @test HTTP.Connections.OPENSSL_POOL[].max == x + @test HTTP.Connections.default_connection_limit[] == x end @testset "@client macro" begin @@ -43,7 +41,7 @@ end end end -@testset "Client.jl" for tls in [MbedTLS.SSLContext, OpenSSL.SSLStream] +@testset "Client.jl" for tls in [MbedTLS.SSLContext, OpenSSL.SSLStream{TCPSocket}] @testset "GET, HEAD, POST, PUT, DELETE, PATCH" begin @test isok(HTTP.get("https://$httpbin/ip", socket_type_tls=tls)) @test isok(HTTP.head("https://$httpbin/ip", socket_type_tls=tls))