Skip to content

Commit

Permalink
Reduce slightly allocations when reading/writing HTTP messages (#950)
Browse files Browse the repository at this point in the history
* Avoid creating additional string when writing headers

* Define our own simpler Version type

* Update src/Messages.jl

Co-authored-by: Jacob Quinn <[email protected]>

* Update src/Messages.jl

Co-authored-by: Jacob Quinn <[email protected]>

* Test we don't allocate

* Tidy tests

* Remove unused code

* Update src/clientlayers/MessageRequest.jl

* Write headers to temporary buffer first

- ...before flushing to socket, to avoid task switching
  mid writing to the socket itself.

* Remove unneeded temp buffer in Stream startwrite

* Make loopback test less flaky by using Event instead of sleep

* Don't connect a socket for Loopback test

- unnecessary and was leading to occasional DNS Errors

Co-authored-by: Jacob Quinn <[email protected]>
  • Loading branch information
nickrobinson251 and quinnj authored Dec 21, 2022
1 parent 7a54ffc commit a2da061
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 45 deletions.
3 changes: 2 additions & 1 deletion src/ConnectionPool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ mutable struct Connection <: IO
timestamp::Float64
readable::Bool
writable::Bool
writebuffer::IOBuffer
state::Any # populated & used by Servers code
end

Expand All @@ -92,7 +93,7 @@ Connection(host::AbstractString, port::AbstractString,
Connection(host, port, idle_timeout,
require_ssl_verification,
safe_getpeername(io)..., localport(io),
io, client, PipeBuffer(), time(), false, false, nothing)
io, client, PipeBuffer(), time(), false, false, IOBuffer(), nothing)

Connection(io; require_ssl_verification::Bool=true) =
Connection("", "", 0, require_ssl_verification, io, false)
Expand Down
62 changes: 32 additions & 30 deletions src/Messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ export Message, Request, Response,

using URIs, CodecZlib
using ..Pairs, ..IOExtras, ..Parsers, ..Strings, ..Forms, ..Conditions
using ..ConnectionPool

const nobody = UInt8[]
const unknown_length = typemax(Int)
Expand All @@ -77,7 +78,7 @@ abstract type Message end
Represents an HTTP response message with fields:
- `version::VersionNumber`
- `version::HTTPVersion`
[RFC7230 2.6](https://tools.ietf.org/html/rfc7230#section-2.6)
- `status::Int16`
Expand All @@ -94,14 +95,14 @@ Represents an HTTP response message with fields:
"""
mutable struct Response <: Message
version::VersionNumber
version::HTTPVersion
status::Int16
headers::Headers
body::Any # Usually Vector{UInt8} or IO
request::Union{Message, Nothing} # Union{Request, Nothing}
end

function Response(status::Integer, headers, body; version::VersionNumber=v"1.1", request=nothing)
function Response(status::Integer, headers, body; version=HTTPVersion(1, 1), request=nothing)
b = isbytes(body) ? bytes(body) : something(body, nobody)
@assert (request isa Request || request === nothing)
return Response(version, status, mkheaders(headers), b, request)
Expand All @@ -119,7 +120,7 @@ Response(body) = Response(200; body=body)
Base.convert(::Type{Response}, s::AbstractString) = Response(s)

function reset!(r::Response)
r.version = v"1.1"
r.version = HTTPVersion(1, 1)
r.status = 0
if !isempty(r.headers)
empty!(r.headers)
Expand All @@ -136,8 +137,10 @@ body(r::Response) = getfield(r, :body)
const Context = Dict{Symbol, Any}

"""
HTTP.Request(method, target, headers=[], body=nobody;
version=v"1.1", url::URI=URI(), responsebody=nothing, parent=nothing, context=HTTP.Context())
HTTP.Request(
method, target, headers=[], body=nobody;
version=v"1.1", url::URI=URI(), responsebody=nothing, parent=nothing, context=HTTP.Context()
)
Represents a HTTP Request Message with fields:
Expand All @@ -147,7 +150,7 @@ Represents a HTTP Request Message with fields:
- `target::String`
[RFC7230 5.3](https://tools.ietf.org/html/rfc7230#section-5.3)
- `version::VersionNumber`
- `version::HTTPVersion`
[RFC7230 2.6](https://tools.ietf.org/html/rfc7230#section-2.6)
- `headers::HTTP.Headers`
Expand All @@ -170,7 +173,7 @@ Represents a HTTP Request Message with fields:
mutable struct Request <: Message
method::String
target::String
version::VersionNumber
version::HTTPVersion
headers::Headers
body::Any # Usually Vector{UInt8} or some kind of IO
response::Response
Expand All @@ -181,8 +184,10 @@ end

Request() = Request("", "")

function Request(method::String, target, headers=[], body=nobody;
version=v"1.1", url::URI=URI(), responsebody=nothing, parent=nothing, context=Context())
function Request(
method::String, target, headers=[], body=nobody;
version=HTTPVersion(1, 1), url::URI=URI(), responsebody=nothing, parent=nothing, context=Context()
)
b = isbytes(body) ? bytes(body) : body
r = Request(method, target == "" ? "/" : target, version,
mkheaders(headers), b, Response(0; body=responsebody),
Expand Down Expand Up @@ -463,26 +468,19 @@ function decode(m::Message, encoding::String="gzip")::Vector{UInt8}
end

# Writing HTTP Messages to IO streams
"""
HTTP.httpversion(::Message)
e.g. `"HTTP/1.1"`
"""
httpversion(m::Message) = "HTTP/$(m.version.major).$(m.version.minor)"
Base.write(io::IO, v::HTTPVersion) = write(io, "HTTP/", string(v.major), ".", string(v.minor))

"""
writestartline(::IO, ::Message)
e.g. `"GET /path HTTP/1.1\\r\\n"` or `"HTTP/1.1 200 OK\\r\\n"`
"""
function writestartline(io::IO, r::Request)
write(io, "$(r.method) $(r.target) $(httpversion(r))\r\n")
return
return write(io, r.method, " ", r.target, " ", r.version, "\r\n")
end

function writestartline(io::IO, r::Response)
write(io, "$(httpversion(r)) $(r.status) $(statustext(r.status))\r\n")
return
return write(io, r.version, " ", string(r.status), " ", statustext(r.status), "\r\n")
end

"""
Expand All @@ -491,14 +489,18 @@ end
Write `Message` start line and
a line for each "name: value" pair and a trailing blank line.
"""
function writeheaders(io::IO, m::Message)
writestartline(io, m)
writeheaders(io::IO, m::Message) = writeheaders(io, m, IOBuffer())
writeheaders(io::Connection, m::Message) = writeheaders(io, m, io.writebuffer)

function writeheaders(io::IO, m::Message, buf::IOBuffer)
writestartline(buf, m)
for (name, value) in m.headers
# match curl convention of not writing empty headers
!isempty(value) && write(io, "$name: $value\r\n")
!isempty(value) && write(buf, name, ": ", value, "\r\n")
end
write(io, "\r\n")
return
write(buf, "\r\n")
nwritten = write(io, take!(buf))
return nwritten
end

"""
Expand All @@ -507,15 +509,15 @@ end
Write start line, headers and body of HTTP Message.
"""
function Base.write(io::IO, m::Message)
writeheaders(io, m)
write(io, m.body)
return
nwritten = writeheaders(io, m)
nwritten += write(io, m.body)
return nwritten
end

function Base.String(m::Message)
io = IOBuffer()
write(io, m)
String(take!(io))
return String(take!(io))
end

# Reading HTTP Messages from IO streams
Expand Down Expand Up @@ -589,7 +591,7 @@ end
function compactstartline(m::Message)
b = IOBuffer()
writestartline(b, m)
strip(String(take!(b)))
return strip(String(take!(b)))
end

# temporary replacement for isvalid(String, s), until the
Expand Down
4 changes: 2 additions & 2 deletions src/Parsers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ function parse_request_line!(bytes::AbstractString, request)::SubString{String}
end
request.method = group(1, re, bytes)
request.target = group(2, re, bytes)
request.version = VersionNumber(group(3, re, bytes))
request.version = HTTPVersion(group(3, re, bytes))
return nextbytes(re, bytes)
end

Expand All @@ -205,7 +205,7 @@ function parse_status_line!(bytes::AbstractString, response)::SubString{String}
if !exec(re, bytes)
throw(ParseError(:INVALID_STATUS_LINE, bytes))
end
response.version = VersionNumber(group(1, re, bytes))
response.version = HTTPVersion(group(1, re, bytes))
response.status = parse(Int, group(2, re, bytes))
return nextbytes(re, bytes)
end
Expand Down
2 changes: 1 addition & 1 deletion src/Servers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ function check_readtimeout(c, readtimeout, wait_for_timeout)
if inactiveseconds(c) > readtimeout
@warnv 2 "Connection Timeout: $c"
try
writeheaders(c.io, Response(408, ["Connection" => "close"]))
writeheaders(c, Response(408, ["Connection" => "close"]))
finally
closeconnection(c)
end
Expand Down
4 changes: 1 addition & 3 deletions src/Streams.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ function IOExtras.startwrite(http::Stream)
else
http.writechunked = ischunked(m)
end
buf = IOBuffer()
writeheaders(buf, m)
n = write(http.stream, take!(buf))
n = writeheaders(http.stream, m)
# nwritten starts at -1 so that we can tell if we've written anything yet
http.nwritten = 0 # should not include headers
return n
Expand Down
81 changes: 80 additions & 1 deletion src/Strings.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,88 @@
module Strings

export escapehtml, tocameldash, iso8859_1_to_utf8, ascii_lc_isequal
export HTTPVersion, escapehtml, tocameldash, iso8859_1_to_utf8, ascii_lc_isequal

using ..IOExtras

# A `Base.VersionNumber` is a SemVer spec, whereas a HTTP versions is just 2 digits,
# This allows us to use a smaller type and more importantly write a simple parse method
# that avoid allocations.
"""
HTTPVersion(major, minor)
The HTTP version number consists of two digits separated by a
"." (period or decimal point). The first digit (`major` version)
indicates the HTTP messaging syntax, whereas the second digit (`minor`
version) indicates the highest minor version within that major
version to which the sender is conformant and able to understand for
future communication.
See [RFC7230 2.6](https://tools.ietf.org/html/rfc7230#section-2.6)
"""
struct HTTPVersion
major::UInt8
minor::UInt8
end

HTTPVersion(major::Integer) = HTTPVersion(major, 0x00)
HTTPVersion(v::AbstractString) = parse(HTTPVersion, v)
HTTPVersion(v::VersionNumber) = convert(HTTPVersion, v)
# Lossy conversion. We ignore patch/prerelease/build parts even if non-zero/non-empty,
# because we don't want to add overhead for a case that should never be relevant.
Base.convert(::Type{HTTPVersion}, v::VersionNumber) = HTTPVersion(v.major, v.minor)
Base.VersionNumber(v::HTTPVersion) = VersionNumber(v.major, v.minor)

Base.show(io::IO, v::HTTPVersion) = print(io, "HTTPVersion(\"", string(v.major), ".", string(v.minor), "\")")

Base.:(==)(va::VersionNumber, vb::HTTPVersion) = va == VersionNumber(vb)
Base.:(==)(va::HTTPVersion, vb::VersionNumber) = VersionNumber(va) == vb

Base.isless(va::VersionNumber, vb::HTTPVersion) = isless(va, VersionNumber(vb))
Base.isless(va::HTTPVersion, vb::VersionNumber) = isless(VersionNumber(va), vb)
function Base.isless(va::HTTPVersion, vb::HTTPVersion)
va.major < vb.major && return true
va.major > vb.major && return false
va.minor < vb.minor && return true
return false
end

function Base.parse(::Type{HTTPVersion}, v::AbstractString)
ver = tryparse(HTTPVersion, v)
ver === nothing && throw(ArgumentError("invalid HTTP version string: $(repr(v))"))
return ver
end

# We only support single-digits for major and minor versions
# - we can parse 0.9 but not 0.10
# - we can parse 9.0 but not 10.0
function Base.tryparse(::Type{HTTPVersion}, v::AbstractString)
isempty(v) && return nothing
len = ncodeunits(v)

i = firstindex(v)
d1 = v[i]
if isdigit(d1)
major = parse(UInt8, d1)
else
return nothing
end

i = nextind(v, i)
i > len && return HTTPVersion(major)
dot = v[i]
dot == '.' || return nothing

i = nextind(v, i)
i > len && return HTTPVersion(major)
d2 = v[i]
if isdigit(d2)
minor = parse(UInt8, d2)
else
return nothing
end
return HTTPVersion(major, minor)
end

"""
escapehtml(i::String)
Expand Down
4 changes: 3 additions & 1 deletion src/clientlayers/MessageRequest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module MessageRequest

using URIs
using ..IOExtras, ..Messages, ..Parsers, ..Exceptions
using ..Messages, ..Parsers
using ..Strings: HTTPVersion

export messagelayer

Expand All @@ -12,7 +14,7 @@ Construct a [`Request`](@ref) object from method, url, headers, and body.
Hard-coded as the first layer in the request pipeline.
"""
function messagelayer(handler)
return function(method::String, url::URI, headers::Headers, body; response_stream=nothing, http_version=v"1.1", kw...)
return function(method::String, url::URI, headers::Headers, body; response_stream=nothing, http_version=HTTPVersion(1, 1), kw...)
req = Request(method, resource(url), headers, body; url=url, version=http_version, responsebody=response_stream)
local resp
try
Expand Down
31 changes: 31 additions & 0 deletions test/httpversion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
@testset "HTTPVersion" begin

# Constructors
@test HTTPVersion(1) == HTTPVersion(1, 0)
@test HTTPVersion(1) == HTTPVersion("1")
@test HTTPVersion(1) == HTTPVersion("1.0")
@test HTTPVersion(1) == HTTPVersion(v"1")
@test HTTPVersion(1) == HTTPVersion(v"1.0")

@test HTTPVersion(1, 1) == HTTPVersion("1.1")
@test HTTPVersion(1, 1) == HTTPVersion(v"1.1")
@test HTTPVersion(1, 1) == HTTPVersion(v"1.1.0")

@test VersionNumber(HTTPVersion(1)) == v"1"
@test VersionNumber(HTTPVersion(1, 1)) == v"1.1"

# Important that we can parse a string into a `HTTPVersion` without allocations,
# as we do this for every request/response. Similarly if we then want a `VersionNumber`.
@test @allocated(HTTPVersion("1.1")) == 0
@test @allocated(VersionNumber(HTTPVersion("1.1"))) == 0

# Test comparisons with `VersionNumber`s
req = HTTP.Request("GET", "http://httpbin.org/anything")
res = HTTP.Response(200)
for r in (req, res)
@test r.version == v"1.1"
@test r.version <= v"1.1"
@test r.version < v"1.2"
end

end # testset
6 changes: 4 additions & 2 deletions test/loopback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Base.readavailable(fio::FunctionIO) = (call(fio); readavailable(fio.buf))
Base.readavailable(lb::Loopback) = readavailable(lb.io)
Base.unsafe_read(lb::Loopback, p::Ptr, n::Integer) = unsafe_read(lb.io, p, n)

HTTP.IOExtras.tcpsocket(::Loopback) = Sockets.connect("$httpbin", 80)
HTTP.IOExtras.tcpsocket(::Loopback) = TCPSocket()

lbreq(req, headers, body; method="GET", kw...) =
HTTP.request(method, "http://test/$req", headers, body; config..., kw...)
Expand Down Expand Up @@ -262,8 +262,9 @@ end
@test_throws HTTP.StatusError begin
r = lbopen("abort", []) do http
@sync begin
event = Base.Event()
@async try
sleep(0.1)
wait(event)
write(http, "Hello World!")
closewrite(http)
body_sent = true
Expand All @@ -278,6 +279,7 @@ end
startread(http)
body = read(http)
closeread(http)
notify(event)
end
end
end
Expand Down
Loading

0 comments on commit a2da061

Please sign in to comment.