Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

permessage-deflate decompression support in websocket #3263

Merged
merged 20 commits into from
May 20, 2024
1 change: 1 addition & 0 deletions lib/web/fetch/data-url.js
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ module.exports = {
collectAnHTTPQuotedString,
serializeAMimeType,
removeChars,
removeHTTPWhitespace,
minimizeSupportedMimeType,
HTTP_TOKEN_CODEPOINTS,
isomorphicDecode
Expand Down
27 changes: 18 additions & 9 deletions lib/web/websocket/connection.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const {
kReceivedClose,
kResponse
} = require('./symbols')
const { fireEvent, failWebsocketConnection, isClosing, isClosed, isEstablished } = require('./util')
const { fireEvent, failWebsocketConnection, isClosing, isClosed, isEstablished, parseExtensions } = require('./util')
const { channels } = require('../../core/diagnostics')
const { CloseEvent } = require('./events')
const { makeRequest } = require('../fetch/request')
Expand All @@ -31,7 +31,7 @@ try {
* @param {URL} url
* @param {string|string[]} protocols
* @param {import('./websocket').WebSocket} ws
* @param {(response: any) => void} onEstablish
* @param {(response: any, extensions: string[] | undefined) => void} onEstablish
* @param {Partial<import('../../types/websocket').WebSocketInit>} options
*/
function establishWebSocketConnection (url, protocols, client, ws, onEstablish, options) {
Expand Down Expand Up @@ -91,12 +91,11 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish,
// 9. Let permessageDeflate be a user-agent defined
// "permessage-deflate" extension header value.
// https://github.com/mozilla/gecko-dev/blob/ce78234f5e653a5d3916813ff990f053510227bc/netwerk/protocol/websocket/WebSocketChannel.cpp#L2673
// TODO: enable once permessage-deflate is supported
const permessageDeflate = '' // 'permessage-deflate; 15'
const permessageDeflate = 'permessage-deflate; client_max_window_bits'

// 10. Append (`Sec-WebSocket-Extensions`, permessageDeflate) to
// request’s header list.
// request.headersList.append('sec-websocket-extensions', permessageDeflate)
request.headersList.append('sec-websocket-extensions', permessageDeflate)

// 11. Fetch request with useParallelQueue set to true, and
// processResponse given response being these steps:
Expand Down Expand Up @@ -167,10 +166,15 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish,
// header field to determine which extensions are requested is
// discussed in Section 9.1.)
const secExtension = response.headersList.get('Sec-WebSocket-Extensions')
let extensions

if (secExtension !== null && secExtension !== permessageDeflate) {
failWebsocketConnection(ws, 'Received different permessage-deflate than the one set.')
return
if (secExtension !== null) {
extensions = parseExtensions(secExtension)

if (!extensions.has('permessage-deflate')) {
failWebsocketConnection(ws, 'Sec-WebSocket-Extensions header does not match.')
return
}
}

// 6. If the response includes a |Sec-WebSocket-Protocol| header field
Expand Down Expand Up @@ -206,7 +210,7 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish,
})
}

onEstablish(response)
onEstablish(response, extensions)
}
})

Expand Down Expand Up @@ -290,6 +294,11 @@ function onSocketData (chunk) {
*/
function onSocketClose () {
const { ws } = this
const { [kResponse]: response } = ws

response.socket.off('data', onSocketData)
response.socket.off('close', onSocketClose)
response.socket.off('error', onSocketError)

// If the TCP connection was closed after the
// WebSocket closing handshake was completed, the WebSocket connection
Expand Down
10 changes: 9 additions & 1 deletion lib/web/websocket/constants.js
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ const parserStates = {

const emptyBuffer = Buffer.allocUnsafe(0)

const sendHints = {
string: 1,
typedArray: 2,
arrayBuffer: 3,
blob: 4
}

module.exports = {
uid,
sentCloseFrameState,
Expand All @@ -54,5 +61,6 @@ module.exports = {
opcodes,
maxUnsigned16Bit,
parserStates,
emptyBuffer
emptyBuffer,
sendHints
}
70 changes: 70 additions & 0 deletions lib/web/websocket/permessage-deflate.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
'use strict'

const { createInflateRaw, Z_DEFAULT_WINDOWBITS } = require('node:zlib')
const { isValidClientWindowBits } = require('./util')

const tail = Buffer.from([0x00, 0x00, 0xff, 0xff])
const kBuffer = Symbol('kBuffer')
const kLength = Symbol('kLength')

class PerMessageDeflate {
/** @type {import('node:zlib').InflateRaw} */
#inflate

#options = {}

constructor (extensions) {
this.#options.serverNoContextTakeover = extensions.has('server_no_context_takeover')
this.#options.serverMaxWindowBits = extensions.get('server_max_window_bits')
}

decompress (chunk, fin, callback) {
// An endpoint uses the following algorithm to decompress a message.
// 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the
// payload of the message.
// 2. Decompress the resulting data using DEFLATE.

if (!this.#inflate) {
let windowBits = Z_DEFAULT_WINDOWBITS

if (this.#options.serverMaxWindowBits) { // empty values default to Z_DEFAULT_WINDOWBITS
if (!isValidClientWindowBits(this.#options.serverMaxWindowBits)) {
callback(new Error('Invalid server_max_window_bits'))
return
}

windowBits = Number.parseInt(this.#options.serverMaxWindowBits)
}

this.#inflate = createInflateRaw({ windowBits })
this.#inflate[kBuffer] = []
this.#inflate[kLength] = 0

this.#inflate.on('data', (data) => {
this.#inflate[kBuffer].push(data)
this.#inflate[kLength] += data.length
})

this.#inflate.on('error', (err) => {
this.#inflate = null
callback(err)
})
}

this.#inflate.write(chunk)
if (fin) {
this.#inflate.write(tail)
}

this.#inflate.flush(() => {
const full = Buffer.concat(this.#inflate[kBuffer], this.#inflate[kLength])

this.#inflate[kBuffer].length = 0
this.#inflate[kLength] = 0

callback(null, full)
})
}
}

module.exports = { PerMessageDeflate }
79 changes: 63 additions & 16 deletions lib/web/websocket/receiver.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const {
} = require('./util')
const { WebsocketFrameSend } = require('./frame')
const { closeWebSocketConnection } = require('./connection')
const { PerMessageDeflate } = require('./permessage-deflate')

// This code was influenced by ws released under the MIT license.
// Copyright (c) 2011 Einar Otto Stangvik <[email protected]>
Expand All @@ -33,10 +34,18 @@ class ByteParser extends Writable {
#info = {}
#fragments = []

constructor (ws) {
/** @type {Map<string, PerMessageDeflate>} */
#extensions

constructor (ws, extensions) {
super()

this.ws = ws
this.#extensions = extensions == null ? new Map() : extensions

if (this.#extensions.has('permessage-deflate')) {
this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions))
}
}

/**
Expand Down Expand Up @@ -91,7 +100,16 @@ class ByteParser extends Writable {
// the negotiated extensions defines the meaning of such a nonzero
// value, the receiving endpoint MUST _Fail the WebSocket
// Connection_.
if (rsv1 !== 0 || rsv2 !== 0 || rsv3 !== 0) {
// This document allocates the RSV1 bit of the WebSocket header for
// PMCEs and calls the bit the "Per-Message Compressed" bit. On a
// WebSocket connection where a PMCE is in use, this bit indicates
// whether a message is compressed or not.
if (rsv1 !== 0 && !this.#extensions.has('permessage-deflate')) {
failWebsocketConnection(this.ws, 'Expected RSV1 to be clear.')
return
}

if (rsv2 !== 0 || rsv3 !== 0) {
failWebsocketConnection(this.ws, 'RSV1, RSV2, RSV3 must be clear')
return
}
Expand Down Expand Up @@ -122,7 +140,7 @@ class ByteParser extends Writable {
return
}

if (isContinuationFrame(opcode) && this.#fragments.length === 0) {
if (isContinuationFrame(opcode) && this.#fragments.length === 0 && !this.#info.compressed) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is probably wrong (?)

failWebsocketConnection(this.ws, 'Unexpected continuation frame')
return
}
Expand All @@ -138,6 +156,7 @@ class ByteParser extends Writable {

if (isTextBinaryFrame(opcode)) {
this.#info.binaryType = opcode
this.#info.compressed = rsv1 !== 0
}

this.#info.opcode = opcode
Expand Down Expand Up @@ -185,21 +204,50 @@ class ByteParser extends Writable {

if (isControlFrame(this.#info.opcode)) {
this.#loop = this.parseControlFrame(body)
this.#state = parserStates.INFO
} else {
this.#fragments.push(body)

// If the frame is not fragmented, a message has been received.
// If the frame is fragmented, it will terminate with a fin bit set
// and an opcode of 0 (continuation), therefore we handle that when
// parsing continuation frames, not here.
if (!this.#info.fragmented && this.#info.fin) {
const fullMessage = Buffer.concat(this.#fragments)
websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage)
this.#fragments.length = 0
if (!this.#info.compressed) {
this.#fragments.push(body)

// If the frame is not fragmented, a message has been received.
// If the frame is fragmented, it will terminate with a fin bit set
// and an opcode of 0 (continuation), therefore we handle that when
// parsing continuation frames, not here.
if (!this.#info.fragmented && this.#info.fin) {
const fullMessage = Buffer.concat(this.#fragments)
websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage)
this.#fragments.length = 0
}

this.#state = parserStates.INFO
} else {
this.#extensions.get('permessage-deflate').decompress(body, this.#info.fin, (error, data) => {
if (error) {
closeWebSocketConnection(this.ws, 1007, error.message, error.message.length)
return
}

this.#fragments.push(data)

if (!this.#info.fin) {
this.#state = parserStates.INFO
this.#loop = true
this.run(callback)
return
}

websocketMessageReceived(this.ws, this.#info.binaryType, Buffer.concat(this.#fragments))

this.#loop = true
this.#state = parserStates.INFO
this.run(callback)
this.#fragments.length = 0
})

this.#loop = false
break
}
}

this.#state = parserStates.INFO
}
}
}
Expand Down Expand Up @@ -333,7 +381,6 @@ class ByteParser extends Writable {
this.ws[kReadyState] = states.CLOSING
this.ws[kReceivedClose] = true

this.end()
return false
} else if (opcode === opcodes.PING) {
// Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in
Expand Down
85 changes: 85 additions & 0 deletions lib/web/websocket/sender.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
'use strict'

const { WebsocketFrameSend } = require('./frame')
const { opcodes, sendHints } = require('./constants')

/** @type {Uint8Array} */
const FastBuffer = Buffer[Symbol.species]

class SendQueue {
#queued = new Set()
#size = 0

/** @type {import('net').Socket} */
#socket

constructor (socket) {
this.#socket = socket
}

add (item, cb, hint) {
if (hint !== sendHints.blob) {
const data = clone(item, hint)

if (this.#size === 0) {
this.#dispatch(data, cb, hint)
} else {
this.#queued.add([data, cb, true, hint])
this.#size++

this.#run()
}

return
}

const promise = item.arrayBuffer()
const queue = [null, cb, false, hint]
promise.then((ab) => {
queue[0] = clone(ab, hint)
queue[2] = true

this.#run()
})

this.#queued.add(queue)
this.#size++
}

#run () {
for (const queued of this.#queued) {
const [data, cb, done, hint] = queued

if (!done) return

this.#queued.delete(queued)
this.#size--

this.#dispatch(data, cb, hint)
}
}

#dispatch (data, cb, hint) {
const frame = new WebsocketFrameSend()
const opcode = hint === sendHints.string ? opcodes.TEXT : opcodes.BINARY

frame.frameData = data
const buffer = frame.createFrame(opcode)

this.#socket.write(buffer, cb)
}
}

function clone (data, hint) {
switch (hint) {
case sendHints.string:
return Buffer.from(data)
case sendHints.arrayBuffer:
case sendHints.blob:
return new FastBuffer(data)
case sendHints.typedArray:
return Buffer.copyBytesFrom(data)
}
}

module.exports = { SendQueue }
Loading
Loading