Skip to content

Commit

Permalink
Add MapSortMode to MarshalOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
rvagg committed Jul 27, 2021
1 parent 4987af7 commit 083c395
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 48 deletions.
66 changes: 66 additions & 0 deletions codec/cbor/roundtrip_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package cbor

import (
"bytes"
"strings"
"testing"

. "github.com/warpfork/go-wish"

"github.com/ipld/go-ipld-prime/fluent"
basicnode "github.com/ipld/go-ipld-prime/node/basic"
)

var n = fluent.MustBuildMap(basicnode.Prototype__Map{}, 4, func(na fluent.MapAssembler) {
na.AssembleEntry("plain").AssignString("olde string")
na.AssembleEntry("map").CreateMap(2, func(na fluent.MapAssembler) {
na.AssembleEntry("one").AssignInt(1)
na.AssembleEntry("two").AssignInt(2)
})
na.AssembleEntry("list").CreateList(2, func(na fluent.ListAssembler) {
na.AssembleValue().AssignString("three")
na.AssembleValue().AssignString("four")
})
na.AssembleEntry("nested").CreateMap(1, func(na fluent.MapAssembler) {
na.AssembleEntry("deeper").CreateList(1, func(na fluent.ListAssembler) {
na.AssembleValue().AssignString("things")
})
})
})

var serial = "\xa4eplainkolde stringcmap\xa2cone\x01ctwo\x02dlist\x82ethreedfourfnested\xa1fdeeper\x81fthings"

func TestRoundtrip(t *testing.T) {
t.Run("encoding", func(t *testing.T) {
var buf bytes.Buffer
err := Encode(n, &buf)
Require(t, err, ShouldEqual, nil)
Wish(t, buf.String(), ShouldEqual, serial)
})
t.Run("decoding", func(t *testing.T) {
buf := strings.NewReader(serial)
nb := basicnode.Prototype__Map{}.NewBuilder()
err := Decode(nb, buf)
Require(t, err, ShouldEqual, nil)
Wish(t, nb.Build(), ShouldEqual, n)
})
}

func TestRoundtripScalar(t *testing.T) {
nb := basicnode.Prototype__String{}.NewBuilder()
nb.AssignString("applesauce")
simple := nb.Build()
t.Run("encoding", func(t *testing.T) {
var buf bytes.Buffer
err := Encode(simple, &buf)
Require(t, err, ShouldEqual, nil)
Wish(t, buf.String(), ShouldEqual, `japplesauce`)
})
t.Run("decoding", func(t *testing.T) {
buf := strings.NewReader(`japplesauce`)
nb := basicnode.Prototype__String{}.NewBuilder()
err := Decode(nb, buf)
Require(t, err, ShouldEqual, nil)
Wish(t, nb.Build(), ShouldEqual, simple)
})
}
127 changes: 81 additions & 46 deletions codec/dagcbor/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,20 @@ import (
// except for the `case ipld.Kind_Link` block,
// which is dag-cbor's special sauce for schemafree links.

const (
MapSortMode_none = iota
MapSortMode_RFC7049
)

type MarshalOptions struct {
// If true, allow encoding of Link nodes as CBOR tag(42), otherwise reject
// them as unencodable
AllowLinks bool

// Control the sorting of map keys, MapSortMode_none for no sorting or
// MapSortMode_RFC7049 for length-first bytewise sorting as per RFC7049 and
// DAG-CBOR
MapSortMode int
}

func Marshal(n ipld.Node, sink shared.TokenSink, options MarshalOptions) error {
Expand All @@ -35,52 +45,7 @@ func marshal(n ipld.Node, tk *tok.Token, sink shared.TokenSink, options MarshalO
_, err := sink.Step(tk)
return err
case ipld.Kind_Map:
// Emit start of map.
tk.Type = tok.TMapOpen
tk.Length = int(n.Length()) // TODO: overflow check
if _, err := sink.Step(tk); err != nil {
return err
}
// Collect map entries, then sort by key
type entry struct {
key string
value ipld.Node
}
entries := []entry{}
for itr := n.MapIterator(); !itr.Done(); {
k, v, err := itr.Next()
if err != nil {
return err
}
keyStr, err := k.AsString()
if err != nil {
return err
}
entries = append(entries, entry{keyStr, v})
}
// RFC7049 style sort as per DAG-CBOR spec
sort.Slice(entries, func(i, j int) bool {
li, lj := len(entries[i].key), len(entries[j].key)
if li == lj {
return entries[i].key < entries[j].key
}
return li < lj
})
// Emit map contents (and recurse).
for _, e := range entries {
tk.Type = tok.TString
tk.Str = e.key
if _, err := sink.Step(tk); err != nil {
return err
}
if err := marshal(e.value, tk, sink, options); err != nil {
return err
}
}
// Emit map close.
tk.Type = tok.TMapClose
_, err := sink.Step(tk)
return err
return marshalMap(n, tk, sink, options)
case ipld.Kind_List:
// Emit start of list.
tk.Type = tok.TArrOpen
Expand Down Expand Up @@ -172,3 +137,73 @@ func marshal(n ipld.Node, tk *tok.Token, sink shared.TokenSink, options MarshalO
panic("unreachable")
}
}

func marshalMap(n ipld.Node, tk *tok.Token, sink shared.TokenSink, options MarshalOptions) error {
// Emit start of map.
tk.Type = tok.TMapOpen
tk.Length = int(n.Length()) // TODO: overflow check
if _, err := sink.Step(tk); err != nil {
return err
}
if options.MapSortMode == MapSortMode_RFC7049 {
// Collect map entries, then sort by key
type entry struct {
key string
value ipld.Node
}
entries := []entry{}
for itr := n.MapIterator(); !itr.Done(); {
k, v, err := itr.Next()
if err != nil {
return err
}
keyStr, err := k.AsString()
if err != nil {
return err
}
entries = append(entries, entry{keyStr, v})
}
// RFC7049 style sort as per DAG-CBOR spec
sort.Slice(entries, func(i, j int) bool {
li, lj := len(entries[i].key), len(entries[j].key)
if li == lj {
return entries[i].key < entries[j].key
}
return li < lj
})
// Emit map contents (and recurse).
for _, e := range entries {
tk.Type = tok.TString
tk.Str = e.key
if _, err := sink.Step(tk); err != nil {
return err
}
if err := marshal(e.value, tk, sink, options); err != nil {
return err
}
}
} else { // no sorting
// Emit map contents (and recurse).
for itr := n.MapIterator(); !itr.Done(); {
k, v, err := itr.Next()
if err != nil {
return err
}
tk.Type = tok.TString
tk.Str, err = k.AsString()
if err != nil {
return err
}
if _, err := sink.Step(tk); err != nil {
return err
}
if err := marshal(v, tk, sink, options); err != nil {
return err
}
}
}
// Emit map close.
tk.Type = tok.TMapClose
_, err := sink.Step(tk)
return err
}
6 changes: 4 additions & 2 deletions codec/dagcbor/multicodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ func Decode(na ipld.NodeAssembler, r io.Reader) error {
return na2.DecodeDagCbor(r)
}
// Okay, generic builder path.
return Unmarshal(na, cbor.NewDecoder(cbor.DecodeOptions{}, r), UnmarshalOptions{AllowLinks: true})
return Unmarshal(na, cbor.NewDecoder(cbor.DecodeOptions{}, r),
UnmarshalOptions{AllowLinks: true})
}

func Encode(n ipld.Node, w io.Writer) error {
Expand All @@ -40,5 +41,6 @@ func Encode(n ipld.Node, w io.Writer) error {
return n2.EncodeDagCbor(w)
}
// Okay, generic inspection path.
return Marshal(n, cbor.NewEncoder(w), MarshalOptions{AllowLinks: true})
return Marshal(n, cbor.NewEncoder(w),
MarshalOptions{AllowLinks: true, MapSortMode: MapSortMode_RFC7049})
}

0 comments on commit 083c395

Please sign in to comment.