diff --git a/stream.go b/stream.go index f60f1290..02fea43c 100644 --- a/stream.go +++ b/stream.go @@ -4,6 +4,7 @@ package cbor import ( + "bytes" "errors" "io" "reflect" @@ -65,6 +66,12 @@ func (dec *Decoder) NumBytesRead() int { return dec.bytesRead } +// Buffered returns a reader for data remaining in Decoder's buffer. +// Returned reader is valid until the next call to Decode or Skip. +func (dec *Decoder) Buffered() io.Reader { + return bytes.NewReader(dec.buf[dec.off:]) +} + // readNext() reads next CBOR data item from Reader to buffer. // It returns the size of next CBOR data item. // It also returns validation error or read error if any. diff --git a/stream_test.go b/stream_test.go index bbd74ac5..095effdd 100644 --- a/stream_test.go +++ b/stream_test.go @@ -577,6 +577,92 @@ func TestDecoderStructTag(t *testing.T) { } } +func TestDecoderBuffered(t *testing.T) { + testCases := []struct { + name string + cborData []byte + buffered []byte + decodeErr error + }{ + { + name: "empty", + cborData: []byte{}, + buffered: []byte{}, + decodeErr: io.EOF, + }, + { + name: "malformed CBOR data item", + cborData: []byte{0xc0}, + buffered: []byte{0xc0}, + decodeErr: io.ErrUnexpectedEOF, + }, + { + name: "1 CBOR data item", + cborData: []byte{0xc2, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + buffered: []byte{}, + }, + { + name: "2 CBOR data items", + cborData: []byte{ + // First CBOR data item + 0xc2, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // Second CBOR data item + 0xc3, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + buffered: []byte{ + // Second CBOR data item + 0xc3, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + { + name: "1 CBOR data item followed by non-CBOR data", + cborData: []byte{ + // CBOR data item + 0xc2, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // Extraneous non-CBOR data ("abc") + 0x61, 0x62, 0x63, + }, + buffered: []byte{ + // non-CBOR data ("abc") + 0x61, 0x62, 0x63, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := bytes.NewReader(tc.cborData) + + decoder := NewDecoder(r) + + // Decoder's buffer has no data yet. + br := decoder.Buffered() + buffered, err := io.ReadAll(br) + if err != nil { + t.Errorf("failed to read from reader returned by Buffered(): %v", err) + } + if len(buffered) > 0 { + t.Errorf("Buffered() = 0x%x (%d bytes), want 0 bytes", buffered, len(buffered)) + } + + var v interface{} + err = decoder.Decode(&v) + if err != tc.decodeErr { + t.Errorf("Decode() returned error %v, want %v", err, tc.decodeErr) + } + + br = decoder.Buffered() + buffered, err = io.ReadAll(br) + if err != nil { + t.Errorf("failed to read from reader returned by Buffered(): %v", err) + } + if !bytes.Equal(tc.buffered, buffered) { + t.Errorf("Buffered() = 0x%x (%d bytes), want 0x%x (%d bytes)", buffered, len(buffered), tc.buffered, len(tc.buffered)) + } + }) + } +} + func TestEncoder(t *testing.T) { var want bytes.Buffer var w bytes.Buffer