diff --git a/cache.go b/cache.go index cd83d8a9..dca581a2 100644 --- a/cache.go +++ b/cache.go @@ -174,13 +174,14 @@ type encodingStructType struct { } func (st *encodingStructType) getFields(em *encMode) fields { - if em.sort == SortNone { + switch em.sort { + case SortNone, SortFastShuffle: return st.fields - } - if em.sort == SortLengthFirst { + case SortLengthFirst: return st.lengthFirstFields + default: + return st.bytewiseFields } - return st.bytewiseFields } type bytewiseFieldSorter struct { diff --git a/encode.go b/encode.go index 79f32986..1ee8ed8d 100644 --- a/encode.go +++ b/encode.go @@ -11,6 +11,7 @@ import ( "io" "math" "math/big" + "math/rand" "reflect" "sort" "strconv" @@ -141,7 +142,7 @@ func (e *UnsupportedValueError) Error() string { type SortMode int const ( - // SortNone means no sorting. + // SortNone encodes map pairs and struct fields in an arbitrary order. SortNone SortMode = 0 // SortLengthFirst causes map keys or struct fields to be sorted such that: @@ -157,6 +158,12 @@ const ( // in RFC 7049bis. SortBytewiseLexical SortMode = 2 + // SortShuffle encodes map pairs and struct fields in a shuffled + // order. This mode does not guarantee an unbiased permutation, but it + // does guarantee that the runtime of the shuffle algorithm used will be + // constant. + SortFastShuffle SortMode = 3 + // SortCanonical is used in "Canonical CBOR" encoding in RFC 7049 3.9. SortCanonical SortMode = SortLengthFirst @@ -166,7 +173,7 @@ const ( // SortCoreDeterministic is used in "Core Deterministic Encoding" in RFC 7049bis. SortCoreDeterministic SortMode = SortBytewiseLexical - maxSortMode SortMode = 3 + maxSortMode SortMode = 4 ) func (sm SortMode) valid() bool { @@ -1081,8 +1088,12 @@ func (me mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) e if mlen == 0 { return e.WriteByte(byte(cborTypeMap)) } - if em.sort != SortNone && mlen > 1 { - return me.encodeCanonical(e, em, v) + switch em.sort { + case SortNone, SortFastShuffle: + default: + if mlen > 1 { + return me.encodeCanonical(e, em, v) + } } encodeHead(e, byte(cborTypeMap), uint64(mlen)) @@ -1234,7 +1245,13 @@ func encodeFixedLengthStruct(e *encoderBuffer, em *encMode, v reflect.Value, fld encodeHead(e, byte(cborTypeMap), uint64(len(flds))) - for i := 0; i < len(flds); i++ { + start := 0 + if em.sort == SortFastShuffle { + start = rand.Intn(len(flds)) //nolint:gosec // Don't need a CSPRNG for deck cutting. + } + + for offset := 0; offset < len(flds); offset++ { + i := (start + offset) % len(flds) f := flds[i] if !f.keyAsInt && em.fieldName == FieldNameToByteString { e.Write(f.cborNameByteString) @@ -1263,9 +1280,16 @@ func encodeStruct(e *encoderBuffer, em *encMode, v reflect.Value) (err error) { return encodeFixedLengthStruct(e, em, v, flds) } + start := 0 + if em.sort == SortFastShuffle { + start = rand.Intn(len(flds)) //nolint:gosec // Don't need a CSPRNG for deck cutting. + } + kve := getEncoderBuffer() // encode key-value pairs based on struct field tag options kvcount := 0 - for i := 0; i < len(flds); i++ { + + for offset := 0; offset < len(flds); offset++ { + i := (start + offset) % len(flds) f := flds[i] var fv reflect.Value diff --git a/encode_test.go b/encode_test.go index 2ed416bd..cfcff564 100644 --- a/encode_test.go +++ b/encode_test.go @@ -4343,3 +4343,53 @@ func TestMarshalerReturnsDisallowedCBORData(t *testing.T) { }) } } + +func TestSortModeFastShuffle(t *testing.T) { + em, err := EncOptions{Sort: SortFastShuffle}.EncMode() + if err != nil { + t.Fatal(err) + } + + // These cases are based on the assumption that even a constant-time shuffle algorithm can + // give an unbiased permutation of the keys when there are exactly 2 keys, so each trial + // should succeed with probability 1/2. + + for _, tc := range []struct { + name string + trials int + in interface{} + }{ + { + name: "fixed length struct", + trials: 1024, + in: struct{ A, B int }{}, + }, + { + name: "variable length struct", + trials: 1024, + in: struct { + A int + B int `cbor:",omitempty"` + }{B: 1}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + first, err := em.Marshal(tc.in) + if err != nil { + t.Fatal(err) + } + + for i := 1; i <= tc.trials; i++ { + next, err := em.Marshal(tc.in) + if err != nil { + t.Fatal(err) + } + if string(first) != string(next) { + return + } + } + + t.Errorf("object encoded identically in %d consecutive trials using SortFastShuffle", tc.trials) + }) + } +}