Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing pre-allocated buffer for response body #984

Merged
merged 6 commits into from
Jan 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions src/Streams.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Streams

export Stream, closebody, isaborted, setstatus
export Stream, closebody, isaborted, setstatus, readall!

using Sockets, LoggingExtras
using ..IOExtras, ..Messages, ..ConnectionPool, ..Conditions, ..Exceptions
Expand Down Expand Up @@ -252,23 +252,21 @@ function Base.read(http::Stream, ::Type{UInt8})
end

function http_unsafe_read(http::Stream, p::Ptr{UInt8}, n::UInt)::Int

ntr = UInt(ntoread(http))
if ntr == 0
return 0
end
ntr == 0 && return 0
# If there is spare space in `p`
# read two extra bytes
# (`\r\n` at end ofchunk).
unsafe_read(http.stream, p, min(n, ntr + (http.readchunked ? 2 : 0)))
# If there is spare space in `p`
# read two extra bytes
n = min(n, ntr) # (`\r\n` at end ofchunk).
n = min(n, ntr)
update_ntoread(http, n)
return n
end

function Base.readbytes!(http::Stream, buf::AbstractVector{UInt8},
n=length(buf))
@require n <= length(buf)
return http_unsafe_read(http, pointer(buf), UInt(n))
return GC.@preserve buf http_unsafe_read(http, pointer(buf), UInt(n))
end

function Base.unsafe_read(http::Stream, p::Ptr{UInt8}, n::UInt)
Expand All @@ -282,14 +280,20 @@ function Base.unsafe_read(http::Stream, p::Ptr{UInt8}, n::UInt)
nothing
end

function Base.readbytes!(http::Stream, buf::IOBuffer, n=bytesavailable(http))
Base.ensureroom(buf, n)
unsafe_read(http, pointer(buf.data, buf.size + 1), n)
@noinline bufcheck(buf, n) = ((buf.size + n) <= length(buf.data)) || throw(ArgumentError("Unable to grow response stream IOBuffer large enough for response body size"))

function Base.readbytes!(http::Stream, buf::Base.GenericIOBuffer, n=bytesavailable(http))
Base.ensureroom(buf, buf.size + n)
# check if there's enough room in buf to write n bytes
bufcheck(buf, n)
data = buf.data
GC.@preserve data unsafe_read(http, pointer(data, buf.size + 1), n)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sleepy drive by review alert

Should this be Base.ensureroom(buf, buf.ptr + n - 1) and pointer(data, buf.ptr) and similarly for bufcheck?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, perhaps, but we're updating buf.size ourselves (and not buf.ptr), so buf.size is correct for this case. From what I can tell, buf.size tracks the last byte written to the IOBuffer, whereas buf.ptr tracks the next by to read from the IOBuffer. So if that's right, then I think tracking buf.size is correct here.

Copy link
Collaborator

@Drvi Drvi Jan 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, writing to IOBuffer should advance both ptr and size:

julia> io = IOBuffer(UInt8[1, 2, 3, 4, 5, 6], write=true)
IOBuffer(data=UInt8[...], readable=false, writable=true, seekable=true, append=false, size=0, maxsize=Inf, ptr=1, mark=-1)

julia> write(io, "aaaa")
4

julia> io
IOBuffer(data=UInt8[...], readable=false, writable=true, seekable=true, append=false, size=4, maxsize=Inf, ptr=5, mark=-1)

julia> seekstart(io)
IOBuffer(data=UInt8[...], readable=false, writable=true, seekable=true, append=false, size=4, maxsize=Inf, ptr=1, mark=-1)

julia> write(io, "aa")
2

julia> io
IOBuffer(data=UInt8[...], readable=false, writable=true, seekable=true, append=false, size=4, maxsize=Inf, ptr=3, mark=-1)

Note that seekstart will allow us reuse the parts of the buffer that were pre-allocated earlier. So I think we should be working with ptr here, as it says which part of buffer is free (>= ptr) and which is not (< ptr)

buf.size += n
end

function Base.read(http::Stream)
buf = PipeBuffer()
Base.read(http::Stream, buf::Base.GenericIOBuffer=PipeBuffer()) = take!(readall!(http, buf))

function readall!(http::Stream, buf::Base.GenericIOBuffer=PipeBuffer())
if ntoread(http) == unknown_length
while !eof(http)
readbytes!(http, buf)
Expand All @@ -299,7 +303,7 @@ function Base.read(http::Stream)
readbytes!(http, buf, ntoread(http))
end
end
return take!(buf)
return buf
end

function Base.readuntil(http::Stream, f::Function)::ByteView
Expand Down
6 changes: 3 additions & 3 deletions src/clientlayers/RetryRequest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ e.g. `Sockets.DNSError`, `Base.EOFError` and `HTTP.StatusError`
"""
function retrylayer(handler)
return function(req::Request; retry::Bool=true, retries::Int=4,
retry_delays::ExponentialBackOff=ExponentialBackOff(n = retries), retry_check=FALSE,
retry_delays::ExponentialBackOff=ExponentialBackOff(n = retries, factor=3.0), retry_check=FALSE,
retry_non_idempotent::Bool=false, kw...)
if !retry || retries == 0
# no retry
Expand Down Expand Up @@ -61,8 +61,8 @@ function retrylayer(handler)
@debugv 1 "🚷 No Retry: $(no_retry_reason(ex, req))"
end
return s, retry
end)

end
)
return retry_request(req; kw...)
end
end
Expand Down
28 changes: 27 additions & 1 deletion src/clientlayers/StreamRequest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,36 @@ function readbody(stream::Stream, res::Response, decompress::Union{Nothing, Bool
end
end

# 2 most common types of IOBuffers
const IOBuffers = Union{IOBuffer, Base.GenericIOBuffer{SubArray{UInt8, 1, Vector{UInt8}, Tuple{UnitRange{Int64}}, true}}}

function readbody!(stream::Stream, res::Response, buf_or_stream)
if !iserror(res)
if isbytes(res.body)
res.body = read(buf_or_stream)
if length(res.body) > 0
# user-provided buffer to read response body into
# specify write=true to make the buffer writable
# but also specify maxsize, which means it won't be grown
# (we don't want to be changing the user's buffer for them)
body = IOBuffer(res.body; write=true, maxsize=length(res.body))
if buf_or_stream isa BufferStream
# if it's a BufferStream, the response body was gzip encoded
# so using the default write is fastest because it utilizes
# readavailable under the hood, for which BufferStream is optimized
write(body, buf_or_stream)
elseif buf_or_stream isa Stream
# for HTTP.Stream, there's already an optimized read method
# that just needs an IOBuffer to write into
readall!(buf_or_stream, body)
else
error("unreachable")
end
else
res.body = read(buf_or_stream)
end
elseif (res.body isa IOBuffers || res.body isa Base.GenericIOBuffer) && buf_or_stream isa Stream
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be simplified, since:

julia> IOBuffers <: Base.GenericIOBuffer
true

# optimization for IOBuffer response_stream to avoid temporary allocations
readall!(buf_or_stream, res.body)
else
write(res.body, buf_or_stream)
end
Expand Down
28 changes: 28 additions & 0 deletions test/client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,34 @@ end
x["headers"]["Host"] == y["headers"]["Host"] &&
x["headers"]["User-Agent"] == y["headers"]["User-Agent"]
end

# pass pre-allocated buffer
body = zeros(UInt8, 100)
r = HTTP.get("https://$httpbin/bytes/100"; response_stream=body, socket_type_tls=tls)
@test body === r.body

# wrapping pre-allocated buffer in IOBuffer will write to buffer directly
io = IOBuffer(body; write=true)
r = HTTP.get("https://$httpbin/bytes/100"; response_stream=io, socket_type_tls=tls)
@test body === r.body.data

# if provided buffer is too small, we won't grow it for user
body = zeros(UInt8, 10)
@test_throws HTTP.RequestError HTTP.get("https://$httpbin/bytes/100"; response_stream=body, socket_type_tls=tls, retry=false)

# also won't shrink it if buffer provided is larger than response body
body = zeros(UInt8, 10)
r = HTTP.get("https://$httpbin/bytes/5"; response_stream=body, socket_type_tls=tls)
@test body === r.body
@test length(body) == 10
@test HTTP.header(r, "Content-Length") == "5"

# but if you wrap it in a writable IOBuffer, we will grow it
io = IOBuffer(body; write=true)
r = HTTP.get("https://$httpbin/bytes/100"; response_stream=io, socket_type_tls=tls)
# same Array, though it was resized larger
@test body === r.body.data
@test length(body) == 100
end

@testset "Client Body Posting - Vector{UTF8}, String, IOStream, IOBuffer, BufferStream, Dict, NamedTuple" begin
Expand Down