From 2c9a6cb325795bfd28a6d9c26448f95faea3cb21 Mon Sep 17 00:00:00 2001 From: Ujin Date: Mon, 27 Jun 2022 05:17:57 +0300 Subject: [PATCH] permessage-deflate support --- src/WebSockets.jl | 57 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/src/WebSockets.jl b/src/WebSockets.jl index 14342f49d..c830d6100 100644 --- a/src/WebSockets.jl +++ b/src/WebSockets.jl @@ -1,6 +1,6 @@ module WebSockets -using Base64, LoggingExtras, UUIDs, Sockets, Random +using Base64, LoggingExtras, UUIDs, Sockets, Random, CodecZlib using MbedTLS: digest, MD_SHA1, SSLContext using ..IOExtras, ..Streams, ..ConnectionPool, ..Messages, ..Conditions, ..Servers import ..open @@ -55,7 +55,7 @@ FrameFlags(final::Bool, opcode::OpCode, masked::Bool, len::Integer; rsv1::Bool=f ) Base.show(io::IO, x::FrameFlags) = - print(io, "FrameFlags(", "final=", x.final, ", ", "opcode=", x.opcode, ", ", "masked=", x.masked, ", ", "len=", x.len, ")") + print(io, "FrameFlags(", "final=", x.final, ", isdeflate=", x.rsv1, ", ", "opcode=", x.opcode, ", ", "masked=", x.masked, ", ", "len=", x.len, ")") primitive type Mask 32 end Base.UInt32(x::Mask) = Base.bitcast(UInt32, x) @@ -91,6 +91,27 @@ function mask!(bytes::Vector{UInt8}, mask) return end +function compress(data::T) where T <: AbstractVector{UInt8} + compressed = transcode(DeflateCompressor, data) + return vcat(compressed, 0x00) +end + +function compress(data::String) + compressed = transcode(DeflateCompressor, data) + return String(vcat(compressed, 0x00)) +end + +function decompress(data::T) where T <: AbstractVector{UInt8} + decompressed = transcode(DeflateDecompressor, vcat(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + return decompressed +end + +function decompress(data::String) + decompressed = transcode(DeflateDecompressor, vcat(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + return String(decompressed) +end + + # send method Frame constructor function Frame(final::Bool, opcode::OpCode, client::Bool, payload::AbstractVector{UInt8}; rsv1::Bool=false, rsv2::Bool=false, rsv3::Bool=false) len, extlen = wslength(length(payload)) @@ -293,12 +314,13 @@ mutable struct WebSocket writebuffer::Vector{UInt8} readclosed::Bool writeclosed::Bool + isdeflate::Bool end const DEFAULT_MAX_FRAG = 1024 -WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG) = - WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false) +WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate::Bool=false) = + WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, isdeflate) """ WebSockets.isclosed(ws) -> Bool @@ -347,7 +369,7 @@ WebSockets.open(url) do ws end ``` """ -function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...) +function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate=false, kw...) key = base64encode(rand(Random.RandomDevice(), UInt8, 16)) headers = [ "Upgrade" => "websocket", @@ -363,13 +385,14 @@ function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, if header(http, "Sec-WebSocket-Accept") != hashedkey(key) throw(WebSocketError("Invalid Sec-WebSocket-Accept\n" * "$(http.message)")) end + isdeflate = occursin("permessage-deflate", header(http, "Sec-Websocket-Extensions")) # later stream logic checks to see if the HTTP message is "complete" # by seeing if ntoread is 0, which is typemax(Int) for websockets by default # so set it to 0 so it's correctly viewed as "complete" once we're done # doing websocket things http.ntoread = 0 io = http.stream - ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation) + ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation, isdeflate) @debugv 2 "$(ws.id): WebSocket opened" try f(ws) @@ -416,7 +439,8 @@ function listen end listen(f, args...; kw...) = Servers.listen(http -> upgrade(f, http; kw...), args...; kw...) listen!(f, args...; kw...) = Servers.listen!(http -> upgrade(f, http; kw...), args...; kw...) -function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...) +function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), + maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate=false, kw...) @debugv 2 "Server websocket upgrade requested" isupgrade(http.message) || handshakeerror() if !hasheader(http, "Sec-WebSocket-Version", "13") @@ -430,10 +454,11 @@ function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=f setheader(http, "Connection" => "Upgrade") key = header(http, "Sec-WebSocket-Key") setheader(http, "Sec-WebSocket-Accept" => hashedkey(key)) + isdeflate && setheader(http, "Sec-Websocket-Extensions" => "permessage-deflate; client_no_context_takeover") startwrite(http) io = http.stream req = http.message - ws = WebSocket(io, req, req.response; client=false, maxframesize, maxfragmentation) + ws = WebSocket(io, req, req.response; client=false, maxframesize, maxfragmentation, isdeflate) @debugv 2 "$(ws.id): WebSocket upgraded; connection established" try f(ws) @@ -507,7 +532,7 @@ function Sockets.send(ws::WebSocket, x) # so we can appropriately set the FIN bit for the last fragmented frame nextstate = iterate(x, st) while true - n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item))) + n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? ws.isdeflate : false)) first = false nextstate === nothing && break item, st = nextstate @@ -516,7 +541,8 @@ function Sockets.send(ws::WebSocket, x) else # single binary or text frame for message @label write_single_frame - return writeframe(ws.io, Frame(true, opcode(x), ws.client, payload(ws, x))) + pl = ws.isdeflate ? compress(payload(ws, x)) : payload(ws, x) + return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=ws.isdeflate)) end end @@ -603,7 +629,7 @@ end @noinline utf8check(x) = isvalid(x) || throw(WebSocketError(CloseFrameBody(1007, "Invalid UTF-8"))) function checkreadframe!(ws::WebSocket, frame::Frame) - if frame.flags.rsv1 || frame.flags.rsv2 || frame.flags.rsv3 + if frame.flags.rsv2 || frame.flags.rsv3 throw(WebSocketError(CloseFrameBody(1002, "Reserved bits set in control frame"))) end opcode = frame.flags.opcode @@ -624,8 +650,6 @@ function checkreadframe!(ws::WebSocket, frame::Frame) elseif opcode == PONG control_len_check(frame.flags.len) return false - elseif frame.flags.final && frame.flags.opcode == TEXT && frame.payload isa String - utf8check(frame.payload) end return frame.flags.final end @@ -659,7 +683,11 @@ function receive(ws::WebSocket) @debugv 2 "$(ws.id): Received frame: $frame" done = checkreadframe!(ws, frame) # common case of reading single non-control frame - done && return frame.payload + if done + payload = ws.isdeflate ? decompress(frame.payload) : frame.payload + payload isa String && utf8check(payload) + return payload + end opcode = frame.flags.opcode iscontrol(opcode) && return receive(ws) # if we're here, we're reading a fragmented message @@ -674,6 +702,7 @@ function receive(ws::WebSocket) end done && break end + payload = ws.isdeflate ? decompress(payload) : payload payload isa String && utf8check(payload) @debugv 2 "Read message: $(payload[1:min(1024, sizeof(payload))])" return payload