Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize masking with math/bits #171

Merged
merged 4 commits into from
Nov 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"bufio"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -81,7 +82,7 @@ type Conn struct {
readerMsgCtx context.Context
readerMsgHeader header
readerFrameEOF bool
readerMaskPos int
readerMaskKey uint32

setReadTimeout chan context.Context
setWriteTimeout chan context.Context
Expand Down Expand Up @@ -324,7 +325,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
}

if h.masked {
fastXOR(h.maskKey, 0, b)
mask(h.maskKey, b)
}

switch h.opcode {
Expand Down Expand Up @@ -446,7 +447,7 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
c.readerMsgCtx = ctx
c.readerMsgHeader = h
c.readerFrameEOF = false
c.readerMaskPos = 0
c.readerMaskKey = h.maskKey
c.readMsgLeft = c.msgReadLimit.Load()

r := &messageReader{
Expand Down Expand Up @@ -532,7 +533,7 @@ func (r *messageReader) read(p []byte, lock bool) (int, error) {

r.c.readerMsgHeader = h
r.c.readerFrameEOF = false
r.c.readerMaskPos = 0
r.c.readerMaskKey = h.maskKey
}

h := r.c.readerMsgHeader
Expand All @@ -545,7 +546,7 @@ func (r *messageReader) read(p []byte, lock bool) (int, error) {
h.payloadLength -= int64(n)
r.c.readMsgLeft -= int64(n)
if h.masked {
r.c.readerMaskPos = fastXOR(h.maskKey, r.c.readerMaskPos, p)
r.c.readerMaskKey = mask(r.c.readerMaskKey, p)
}
r.c.readerMsgHeader = h

Expand Down Expand Up @@ -761,7 +762,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
c.writeHeader.payloadLength = int64(len(p))

if c.client {
_, err := io.ReadFull(rand.Reader, c.writeHeader.maskKey[:])
err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey)
if err != nil {
return 0, fmt.Errorf("failed to generate masking key: %w", err)
}
Expand Down Expand Up @@ -809,7 +810,7 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e
}

if c.client {
var keypos int
maskKey := h.maskKey
for len(p) > 0 {
if c.bw.Available() == 0 {
err = c.bw.Flush()
Expand All @@ -831,7 +832,7 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e
return n, err
}

keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2])
maskKey = mask(maskKey, c.writeBuf[i:i+n2])

p = p[n2:]
n += n2
Expand Down
2 changes: 1 addition & 1 deletion conn_export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) {
return 0, nil, err
}
if h.masked {
fastXOR(h.maskKey, 0, b)
mask(h.maskKey, b)
}
return OpCode(h.opcode), b, nil
}
Expand Down
165 changes: 84 additions & 81 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"math"
"math/bits"
)

//go:generate stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go
Expand Down Expand Up @@ -69,7 +70,7 @@ type header struct {
payloadLength int64

masked bool
maskKey [4]byte
maskKey uint32
}

func makeWriteHeaderBuf() []byte {
Expand Down Expand Up @@ -119,7 +120,7 @@ func writeHeader(b []byte, h header) []byte {
if h.masked {
b[1] |= 1 << 7
b = b[:len(b)+4]
copy(b[len(b)-4:], h.maskKey[:])
binary.LittleEndian.PutUint32(b[len(b)-4:], h.maskKey)
}

return b
Expand Down Expand Up @@ -192,7 +193,7 @@ func readHeader(b []byte, r io.Reader) (header, error) {
}

if h.masked {
copy(h.maskKey[:], b)
h.maskKey = binary.LittleEndian.Uint32(b)
}

return h, nil
Expand Down Expand Up @@ -321,122 +322,124 @@ func (ce CloseError) bytes() ([]byte, error) {
return buf, nil
}

// xor applies the WebSocket masking algorithm to p
// with the given key where the first 3 bits of pos
// are the starting position in the key.
// fastMask applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the position of the next byte
// to be used for masking in the key. This is so that
// unmasking can be performed without the entire frame.
func fastXOR(key [4]byte, keyPos int, b []byte) int {
// If the payload is greater than or equal to 16 bytes, then it's worth
// masking 8 bytes at a time.
// Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859
if len(b) >= 16 {
// We first create a key that is 8 bytes long
// and is aligned on the position correctly.
var alignedKey [8]byte
for i := range alignedKey {
alignedKey[i] = key[(i+keyPos)&3]
}
k := binary.LittleEndian.Uint64(alignedKey[:])
// The returned value is the correctly rotated key to
// to continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
func mask(key uint32, b []byte) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)

// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401

// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
v = binary.LittleEndian.Uint64(b[16:])
binary.LittleEndian.PutUint64(b[16:], v^k)
v = binary.LittleEndian.Uint64(b[24:])
binary.LittleEndian.PutUint64(b[24:], v^k)
v = binary.LittleEndian.Uint64(b[32:])
binary.LittleEndian.PutUint64(b[32:], v^k)
v = binary.LittleEndian.Uint64(b[40:])
binary.LittleEndian.PutUint64(b[40:], v^k)
v = binary.LittleEndian.Uint64(b[48:])
binary.LittleEndian.PutUint64(b[48:], v^k)
v = binary.LittleEndian.Uint64(b[56:])
binary.LittleEndian.PutUint64(b[56:], v^k)
v = binary.LittleEndian.Uint64(b[64:])
binary.LittleEndian.PutUint64(b[64:], v^k)
v = binary.LittleEndian.Uint64(b[72:])
binary.LittleEndian.PutUint64(b[72:], v^k)
v = binary.LittleEndian.Uint64(b[80:])
binary.LittleEndian.PutUint64(b[80:], v^k)
v = binary.LittleEndian.Uint64(b[88:])
binary.LittleEndian.PutUint64(b[88:], v^k)
v = binary.LittleEndian.Uint64(b[96:])
binary.LittleEndian.PutUint64(b[96:], v^k)
v = binary.LittleEndian.Uint64(b[104:])
binary.LittleEndian.PutUint64(b[104:], v^k)
v = binary.LittleEndian.Uint64(b[112:])
binary.LittleEndian.PutUint64(b[112:], v^k)
v = binary.LittleEndian.Uint64(b[120:])
binary.LittleEndian.PutUint64(b[120:], v^k)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}

// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
v = binary.LittleEndian.Uint64(b[16:])
binary.LittleEndian.PutUint64(b[16:], v^k)
v = binary.LittleEndian.Uint64(b[24:])
binary.LittleEndian.PutUint64(b[24:], v^k)
v = binary.LittleEndian.Uint64(b[32:])
binary.LittleEndian.PutUint64(b[32:], v^k)
v = binary.LittleEndian.Uint64(b[40:])
binary.LittleEndian.PutUint64(b[40:], v^k)
v = binary.LittleEndian.Uint64(b[48:])
binary.LittleEndian.PutUint64(b[48:], v^k)
v = binary.LittleEndian.Uint64(b[56:])
binary.LittleEndian.PutUint64(b[56:], v^k)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}

// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
v = binary.LittleEndian.Uint64(b[16:])
binary.LittleEndian.PutUint64(b[16:], v^k)
v = binary.LittleEndian.Uint64(b[24:])
binary.LittleEndian.PutUint64(b[24:], v^k)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}

// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}

// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}

// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}

// xor remaining bytes.
for i := range b {
b[i] ^= key[keyPos&3]
keyPos++
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}
return keyPos & 3

return key
}
Loading