Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to decode CBOR bignum to interface{} as *big.Int #456

Merged
merged 1 commit into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ const (
// FieldNameMatchingPreferCaseSensitive prefers to decode map items into struct fields whose names (or tag
// names) exactly match the item's key. If there is no such field, a map item will be decoded into a field whose
// name is a case-insensitive match for the item's key.
FieldNameMatchingPreferCaseSensitive = iota
FieldNameMatchingPreferCaseSensitive FieldNameMatchingMode = iota

// FieldNameMatchingCaseSensitive decodes map items only into a struct field whose name (or tag name) is an
// exact match for the item's key.
Expand All @@ -374,6 +374,25 @@ func (fnmm FieldNameMatchingMode) valid() bool {
return fnmm >= 0 && fnmm < maxFieldNameMatchingMode
}

// BigIntDecMode specifies how to decode CBOR bignum to Go interface{}.
type BigIntDecMode int

const (
// BigIntDecodeValue makes CBOR bignum decode to big.Int (instead of *big.Int)
// when unmarshalling into a Go interface{}.
BigIntDecodeValue BigIntDecMode = iota

// BigIntDecodePointer makes CBOR bignum decode to *big.Int when
// unmarshalling into a Go interface{}.
BigIntDecodePointer

maxBigIntDecMode
)

func (bidm BigIntDecMode) valid() bool {
return bidm >= 0 && bidm < maxBigIntDecMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -426,6 +445,9 @@ type DecOptions struct {

// FieldNameMatching specifies how string keys in CBOR maps are matched to Go struct field names.
FieldNameMatching FieldNameMatchingMode

// BigIntDec specifies how to decode CBOR bignum to Go interface{}.
BigIntDec BigIntDecMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -537,6 +559,9 @@ func (opts DecOptions) decMode() (*decMode, error) {
if !opts.FieldNameMatching.valid() {
return nil, errors.New("cbor: invalid FieldNameMatching " + strconv.Itoa(int(opts.FieldNameMatching)))
}
if !opts.BigIntDec.valid() {
return nil, errors.New("cbor: invalid BigIntDec " + strconv.Itoa(int(opts.BigIntDec)))
}
dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand All @@ -551,6 +576,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
bigIntDec: opts.BigIntDec,
}
return &dm, nil
}
Expand Down Expand Up @@ -616,6 +642,7 @@ type decMode struct {
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
bigIntDec BigIntDecMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand All @@ -635,6 +662,7 @@ func (dm *decMode) DecOptions() DecOptions {
ExtraReturnErrors: dm.extraReturnErrors,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
BigIntDec: dm.bigIntDec,
}
}

Expand Down Expand Up @@ -1184,6 +1212,10 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
bi := new(big.Int).SetUint64(val)
bi.Add(bi, big.NewInt(1))
bi.Neg(bi)

if d.dm.bigIntDec == BigIntDecodePointer {
return bi, nil
}
return *bi, nil
}
nValue := int64(-1) ^ int64(val)
Expand All @@ -1208,12 +1240,20 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
case 2:
b := d.parseByteString()
bi := new(big.Int).SetBytes(b)

if d.dm.bigIntDec == BigIntDecodePointer {
return bi, nil
}
return *bi, nil
case 3:
b := d.parseByteString()
bi := new(big.Int).SetBytes(b)
bi.Add(bi, big.NewInt(1))
bi.Neg(bi)

if d.dm.bigIntDec == BigIntDecodePointer {
return bi, nil
}
return *bi, nil
}

Expand Down
102 changes: 102 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6263,3 +6263,105 @@ func TestDecodeFieldNameMatching(t *testing.T) {
})
}
}

func TestInvalidBigIntDecMode(t *testing.T) {
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{BigIntDec: -1},
wantErrorMsg: "cbor: invalid BigIntDec -1",
},
{
name: "above range of valid modes",
opts: DecOptions{BigIntDec: 101},
wantErrorMsg: "cbor: invalid BigIntDec 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

func TestDecodeBignumToEmptyInterface(t *testing.T) {
decOptionsDecodeToBigIntValue := DecOptions{BigIntDec: BigIntDecodeValue}
decOptionsDecodeToBigIntPointer := DecOptions{BigIntDec: BigIntDecodePointer}

cborDataPositiveBignum := hexDecode("c249010000000000000000") // positive bignum: 18446744073709551616
pbn, _ := new(big.Int).SetString("18446744073709551616", 10)

cborDataNegativeBignum := hexDecode("c349010000000000000000") // negative bignum: -18446744073709551617
nbn, _ := new(big.Int).SetString("-18446744073709551617", 10)

cborDataLargeNegativeInt := hexDecode("3bffffffffffffffff") // -18446744073709551616
ni, _ := new(big.Int).SetString("-18446744073709551616", 10)

testCases := []struct {
name string
opts DecOptions
cborData []byte
wantValue interface{}
}{
{
name: "decode positive bignum to big.Int",
opts: decOptionsDecodeToBigIntValue,
cborData: cborDataPositiveBignum,
wantValue: *pbn,
},
{
name: "decode negative bignum to big.Int",
opts: decOptionsDecodeToBigIntValue,
cborData: cborDataNegativeBignum,
wantValue: *nbn,
},
{
name: "decode large negative int to big.Int",
opts: decOptionsDecodeToBigIntValue,
cborData: cborDataLargeNegativeInt,
wantValue: *ni,
},
{
name: "decode positive bignum to *big.Int",
opts: decOptionsDecodeToBigIntPointer,
cborData: cborDataPositiveBignum,
wantValue: pbn,
},
{
name: "decode negative bignum to *big.Int",
opts: decOptionsDecodeToBigIntPointer,
cborData: cborDataNegativeBignum,
wantValue: nbn,
},
{
name: "decode large negative int to *big.Int",
opts: decOptionsDecodeToBigIntPointer,
cborData: cborDataLargeNegativeInt,
wantValue: ni,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
decMode, _ := tc.opts.DecMode()

var v interface{}
err := decMode.Unmarshal(tc.cborData, &v)
if err != nil {
t.Errorf("Unmarshal(0x%x) to empty interface returned error %v", tc.cborData, err)
} else {
if !reflect.DeepEqual(v, tc.wantValue) {
t.Errorf("Unmarshal(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.wantValue, tc.wantValue)
}
}
})
}
}
Loading