Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

permessage-deflate support #868

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 62 additions & 16 deletions src/WebSockets.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module WebSockets

using Base64, LoggingExtras, UUIDs, Sockets, Random
using Base64, LoggingExtras, UUIDs, Sockets, Random, CodecZlib
using MbedTLS: digest, MD_SHA1, SSLContext
using ..IOExtras, ..Streams, ..ConnectionPool, ..Messages, ..Conditions, ..Servers
import ..open
Expand Down Expand Up @@ -55,7 +55,7 @@ FrameFlags(final::Bool, opcode::OpCode, masked::Bool, len::Integer; rsv1::Bool=f
)

Base.show(io::IO, x::FrameFlags) =
print(io, "FrameFlags(", "final=", x.final, ", ", "opcode=", x.opcode, ", ", "masked=", x.masked, ", ", "len=", x.len, ")")
print(io, "FrameFlags(", "final=", x.final, ", isdeflate=", x.rsv1, ", ", "opcode=", x.opcode, ", ", "masked=", x.masked, ", ", "len=", x.len, ")")

primitive type Mask 32 end
Base.UInt32(x::Mask) = Base.bitcast(UInt32, x)
Expand Down Expand Up @@ -90,6 +90,20 @@ function mask!(bytes::Vector{UInt8}, mask)
end
return
end
function final_deflate_codecs(t::Tuple)
CodecZlib.TranscodingStreams.finalize(t[1])
CodecZlib.TranscodingStreams.finalize(t[2])
end

function init_deflate_codecs()
codecco = DeflateCompressor()
CodecZlib.TranscodingStreams.initialize(codecco)
codecde = DeflateDecompressor()
CodecZlib.TranscodingStreams.initialize(codecde)

return (codecco, codecde)
end


# send method Frame constructor
function Frame(final::Bool, opcode::OpCode, client::Bool, payload::AbstractVector{UInt8}; rsv1::Bool=false, rsv2::Bool=false, rsv3::Bool=false)
Expand Down Expand Up @@ -293,19 +307,21 @@ mutable struct WebSocket
writebuffer::Vector{UInt8}
readclosed::Bool
writeclosed::Bool
deflate::Union{Nothing, Tuple{CodecZlib.CompressorCodec, CodecZlib.DecompressorCodec}}
end

const DEFAULT_MAX_FRAG = 1024

WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG) =
WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false)
WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate::Bool=false) =
WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, isdeflate ? init_deflate_codecs() : nothing)

"""
WebSockets.isclosed(ws) -> Bool

Check whether a `WebSocket` has sent and received CLOSE frames.
"""
isclosed(ws::WebSocket) = ws.readclosed && ws.writeclosed
isdeflate(ws::WebSocket) = !isnothing(ws.deflate)

# Handshake
"Check whether a HTTP.Request or HTTP.Response is a websocket upgrade request/response"
Expand Down Expand Up @@ -347,7 +363,7 @@ WebSockets.open(url) do ws
end
```
"""
function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...)
function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate=false, kw...)
key = base64encode(rand(Random.RandomDevice(), UInt8, 16))
headers = [
"Upgrade" => "websocket",
Expand All @@ -363,13 +379,14 @@ function open(f::Function, url; suppress_close_error::Bool=false, verbose=false,
if header(http, "Sec-WebSocket-Accept") != hashedkey(key)
throw(WebSocketError("Invalid Sec-WebSocket-Accept\n" * "$(http.message)"))
end
isdeflate = occursin("permessage-deflate", header(http, "Sec-Websocket-Extensions"))
# later stream logic checks to see if the HTTP message is "complete"
# by seeing if ntoread is 0, which is typemax(Int) for websockets by default
# so set it to 0 so it's correctly viewed as "complete" once we're done
# doing websocket things
http.ntoread = 0
io = http.stream
ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation)
ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation, isdeflate)
@debugv 2 "$(ws.id): WebSocket opened"
try
f(ws)
Expand Down Expand Up @@ -416,7 +433,8 @@ function listen end
listen(f, args...; kw...) = Servers.listen(http -> upgrade(f, http; kw...), args...; kw...)
listen!(f, args...; kw...) = Servers.listen!(http -> upgrade(f, http; kw...), args...; kw...)

function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...)
function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int),
maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate=false, kw...)
@debugv 2 "Server websocket upgrade requested"
isupgrade(http.message) || handshakeerror()
if !hasheader(http, "Sec-WebSocket-Version", "13")
Expand All @@ -430,10 +448,11 @@ function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=f
setheader(http, "Connection" => "Upgrade")
key = header(http, "Sec-WebSocket-Key")
setheader(http, "Sec-WebSocket-Accept" => hashedkey(key))
isdeflate && setheader(http, "Sec-Websocket-Extensions" => "permessage-deflate; client_no_context_takeover")
startwrite(http)
io = http.stream
req = http.message
ws = WebSocket(io, req, req.response; client=false, maxframesize, maxfragmentation)
ws = WebSocket(io, req, req.response; client=false, maxframesize, maxfragmentation, isdeflate)
@debugv 2 "$(ws.id): WebSocket upgraded; connection established"
try
f(ws)
Expand Down Expand Up @@ -507,7 +526,7 @@ function Sockets.send(ws::WebSocket, x)
# so we can appropriately set the FIN bit for the last fragmented frame
nextstate = iterate(x, st)
while true
n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item)))
n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? isdeflate(ws) : false))
first = false
nextstate === nothing && break
item, st = nextstate
Expand All @@ -516,7 +535,8 @@ function Sockets.send(ws::WebSocket, x)
else
# single binary or text frame for message
@label write_single_frame
return writeframe(ws.io, Frame(true, opcode(x), ws.client, payload(ws, x)))
pl = isdeflate(ws) ? compress(ws, payload(ws, x)) : payload(ws, x)
return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=isdeflate(ws)))
end
end

Expand All @@ -531,7 +551,7 @@ to when a PING message is received by a websocket connection.
function ping(ws::WebSocket, data=UInt8[])
@require !ws.writeclosed
@debugv 2 "$(ws.id): sending ping"
return writeframe(ws.io, Frame(true, PING, ws.client, payload(ws, data)))
return writeframe(ws.io, Frame(true, PING, ws.client, payload(ws, isdeflate(ws) ? compress(ws, data) : data)))
end

"""
Expand Down Expand Up @@ -592,18 +612,41 @@ function Base.close(ws::WebSocket, body::CloseFrameBody=CloseFrameBody(1000, "")
@assert ws.readclosed
# if we're the server, it's our job to close the underlying socket
!ws.client && isopen(ws.io) && close(ws.io)
final_deflate_codecs(ws.deflate)
return
end

# Receiving messages

function compress(ws::WebSocket, data::T) where T <: AbstractVector{UInt8}
compressed = transcode(ws.deflate[1], data)
push!(compressed, 0x00)
return compressed
end

function compress(ws::WebSocket, data::String)
compressed = transcode(ws.deflate[1], data)
push!(compressed, 0x00)
return String(compressed)
end

function decompress(ws::WebSocket, data::T) where T <: AbstractVector{UInt8}
decompressed = transcode(ws.deflate[2], append!(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00]))
return decompressed
end

function decompress(ws::WebSocket, data::String)
decompressed = transcode(ws.deflate[2], append!(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00]))
return String(decompressed)
end

# returns whether additional frames should be read
# true if fragmented message or a ping/pong frame was handled
@noinline control_len_check(len) = len > 125 && throw(WebSocketError(CloseFrameBody(1002, "Invalid length for control frame")))
@noinline utf8check(x) = isvalid(x) || throw(WebSocketError(CloseFrameBody(1007, "Invalid UTF-8")))

function checkreadframe!(ws::WebSocket, frame::Frame)
if frame.flags.rsv1 || frame.flags.rsv2 || frame.flags.rsv3
if frame.flags.rsv2 || frame.flags.rsv3
throw(WebSocketError(CloseFrameBody(1002, "Reserved bits set in control frame")))
end
opcode = frame.flags.opcode
Expand All @@ -616,16 +659,14 @@ function checkreadframe!(ws::WebSocket, frame::Frame)
if !ws.writeclosed
close(ws)
end
throw(WebSocketError(frame.payload))
throw(WebSocketError(isdeflate(ws) ? decompress(ws, frame.payload) : frame.payload))
elseif opcode == PING
control_len_check(frame.flags.len)
pong(ws, frame.payload)
return false
elseif opcode == PONG
control_len_check(frame.flags.len)
return false
elseif frame.flags.final && frame.flags.opcode == TEXT && frame.payload isa String
utf8check(frame.payload)
end
return frame.flags.final
end
Expand Down Expand Up @@ -659,7 +700,11 @@ function receive(ws::WebSocket)
@debugv 2 "$(ws.id): Received frame: $frame"
done = checkreadframe!(ws, frame)
# common case of reading single non-control frame
done && return frame.payload
if done
payload = isdeflate(ws) ? decompress(ws, frame.payload) : frame.payload
payload isa String && utf8check(payload)
return payload
end
opcode = frame.flags.opcode
iscontrol(opcode) && return receive(ws)
# if we're here, we're reading a fragmented message
Expand All @@ -674,6 +719,7 @@ function receive(ws::WebSocket)
end
done && break
end
payload = isdeflate(ws) ? decompress(ws, payload) : payload
payload isa String && utf8check(payload)
@debugv 2 "Read message: $(payload[1:min(1024, sizeof(payload))])"
return payload
Expand Down