diff --git a/decode.go b/decode.go index e37931f4..64651918 100644 --- a/decode.go +++ b/decode.go @@ -539,6 +539,43 @@ func NewSimpleValueRegistryFromDefaults(fns ...func(*SimpleValueRegistry) error) return &r, nil } +// NaNMode specifies how to decode floating-point values (major type 7, additional information 25 +// through 27) representing NaN (not-a-number). +type NaNMode int + +const ( + // NaNDecodeAllowed will decode NaN values to Go float32 or float64. + NaNDecodeAllowed NaNMode = iota + + // NaNDecodeForbidden will return an UnacceptableDataItemError on an attempt to decode a NaN value. + NaNDecodeForbidden + + maxNaNDecode +) + +func (ndm NaNMode) valid() bool { + return ndm >= 0 && ndm < maxNaNDecode +} + +// InfMode specifies how to decode floating-point values (major type 7, additional information 25 +// through 27) representing positive or negative infinity. +type InfMode int + +const ( + // InfDecodeAllowed will decode infinite values to Go float32 or float64. + InfDecodeAllowed InfMode = iota + + // InfDecodeForbidden will return an UnacceptableDataItemError on an attempt to decode an + // infinite value. + InfDecodeForbidden + + maxInfDecode +) + +func (idm InfMode) valid() bool { + return idm >= 0 && idm < maxInfDecode +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -645,6 +682,14 @@ type DecOptions struct { // Users may provide a custom SimpleValueRegistry constructed via // NewSimpleValueRegistryFromDefaults. SimpleValues *SimpleValueRegistry + + // NaN specifies how to decode floating-point values (major type 7, additional information + // 25 through 27) representing NaN (not-a-number). + NaN NaNMode + + // Inf specifies how to decode floating-point values (major type 7, additional information + // 25 through 27) representing positive or negative infinity. + Inf InfMode } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -815,6 +860,14 @@ func (opts DecOptions) decMode() (*decMode, error) { return nil, errors.New("cbor: invalid TimeTagToAny " + strconv.Itoa(int(opts.TimeTagToAny))) } + if !opts.NaN.valid() { + return nil, errors.New("cbor: invalid NaNDec " + strconv.Itoa(int(opts.NaN))) + } + + if !opts.Inf.valid() { + return nil, errors.New("cbor: invalid InfDec " + strconv.Itoa(int(opts.Inf))) + } + dm := decMode{ dupMapKey: opts.DupMapKey, timeTag: opts.TimeTag, @@ -836,6 +889,8 @@ func (opts DecOptions) decMode() (*decMode, error) { unrecognizedTagToAny: opts.UnrecognizedTagToAny, timeTagToAny: opts.TimeTagToAny, simpleValues: simpleValues, + nanDec: opts.NaN, + infDec: opts.Inf, } return &dm, nil @@ -909,6 +964,8 @@ type decMode struct { unrecognizedTagToAny UnrecognizedTagToAnyMode timeTagToAny TimeTagToAnyMode simpleValues *SimpleValueRegistry + nanDec NaNMode + infDec InfMode } var defaultDecMode, _ = DecOptions{}.decMode() @@ -943,6 +1000,8 @@ func (dm *decMode) DecOptions() DecOptions { UnrecognizedTagToAny: dm.unrecognizedTagToAny, TimeTagToAny: dm.timeTagToAny, SimpleValues: simpleValues, + NaN: dm.nanDec, + Inf: dm.infDec, } } diff --git a/decode_test.go b/decode_test.go index afe287a0..52e6dd17 100644 --- a/decode_test.go +++ b/decode_test.go @@ -4919,6 +4919,8 @@ func TestDecOptions(t *testing.T) { UnrecognizedTagToAny: UnrecognizedTagContentToAny, TimeTagToAny: TimeTagToRFC3339, SimpleValues: simpleValues, + NaN: NaNDecodeForbidden, + Inf: InfDecodeForbidden, } ov := reflect.ValueOf(opts1) for i := 0; i < ov.NumField(); i++ { @@ -8908,3 +8910,495 @@ func TestDecModeTimeTagToAny(t *testing.T) { }) } } + +func TestDecModeInvalidNaNDec(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{NaN: -1}, + wantErrorMsg: "cbor: invalid NaNDec -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{NaN: 101}, + wantErrorMsg: "cbor: invalid NaNDec 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func TestNaNDecMode(t *testing.T) { + for _, tc := range []struct { + opt NaNMode + src []byte + dst interface{} + reject bool + }{ + { + opt: NaNDecodeForbidden, + src: hexDecode("197e00"), + dst: new(interface{}), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("f90000"), // 0.0 + dst: new(interface{}), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("f90000"), // 0.0 + dst: new(float32), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("f90000"), // 0.0 + dst: new(float64), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("f90000"), // 0.0 + dst: new(time.Time), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fa47c35000"), // 100000.0 + dst: new(interface{}), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fa47c35000"), // 100000.0 + dst: new(float32), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fa47c35000"), // 100000.0 + dst: new(float64), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fa47c35000"), // 100000.0 + dst: new(time.Time), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fb3ff199999999999a"), // 1.1 + dst: new(interface{}), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fb3ff199999999999a"), // 1.1 + dst: new(float32), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fb3ff199999999999a"), // 1.1 + dst: new(float64), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fb3ff199999999999a"), // 1.1 + dst: new(time.Time), + reject: false, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("f97e00"), + dst: new(interface{}), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("f97e00"), + dst: new(float32), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("f97e00"), + dst: new(float64), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("f97e00"), + dst: new(time.Time), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fa7fc00000"), + dst: new(interface{}), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fa7fc00000"), + dst: new(float32), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fa7fc00000"), + dst: new(float64), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fa7fc00000"), + dst: new(time.Time), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fb7ff8000000000000"), + dst: new(interface{}), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fb7ff8000000000000"), + dst: new(float32), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fb7ff8000000000000"), + dst: new(float64), + reject: true, + }, + { + opt: NaNDecodeForbidden, + src: hexDecode("fb7ff8000000000000"), + dst: new(time.Time), + reject: true, + }, + } { + t.Run(fmt.Sprintf("mode=%d/0x%x into %s", tc.opt, tc.src, reflect.TypeOf(tc.dst).String()), func(t *testing.T) { + dm, err := DecOptions{NaN: tc.opt}.DecMode() + if err != nil { + t.Fatal(err) + } + want := &UnacceptableDataItemError{ + CBORType: cborTypePrimitives.String(), + Message: "floating-point NaN", + } + if got := dm.Unmarshal(tc.src, tc.dst); got != nil { + if tc.reject { + if !reflect.DeepEqual(want, got) { + t.Errorf("want error: %v, got error: %v", want, got) + } + } else { + t.Errorf("unexpected error: %v", got) + } + } else if tc.reject { + t.Error("unexpected nil error") + } + }) + } +} + +func TestDecModeInvalidInfDec(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{Inf: -1}, + wantErrorMsg: "cbor: invalid InfDec -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{Inf: 101}, + wantErrorMsg: "cbor: invalid InfDec 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func TestInfDecMode(t *testing.T) { + for _, tc := range []struct { + opt InfMode + src []byte + dst interface{} + reject bool + }{ + { + opt: InfDecodeForbidden, + src: hexDecode("197c00"), + dst: new(interface{}), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f90000"), // 0.0 + dst: new(interface{}), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f90000"), // 0.0 + dst: new(float32), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f90000"), // 0.0 + dst: new(float64), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f90000"), // 0.0 + dst: new(time.Time), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fa47c35000"), // 100000.0 + dst: new(interface{}), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fa47c35000"), // 100000.0 + dst: new(float32), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fa47c35000"), // 100000.0 + dst: new(float64), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fa47c35000"), // 100000.0 + dst: new(time.Time), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fb3ff199999999999a"), // 1.1 + dst: new(interface{}), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fb3ff199999999999a"), // 1.1 + dst: new(float32), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fb3ff199999999999a"), // 1.1 + dst: new(float64), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fb3ff199999999999a"), // 1.1 + dst: new(time.Time), + reject: false, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f97c00"), // Infinity + dst: new(interface{}), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f97c00"), // Infinity + dst: new(float32), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f97c00"), // Infinity + dst: new(float64), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f97c00"), // Infinity + dst: new(time.Time), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f9fc00"), // -Infinity + dst: new(interface{}), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f9fc00"), // -Infinity + dst: new(float32), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f9fc00"), // -Infinity + dst: new(float64), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("f9fc00"), // -Infinity + dst: new(time.Time), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fa7f800000"), // Infinity + dst: new(interface{}), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fa7f800000"), // Infinity + dst: new(float32), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fa7f800000"), // Infinity + dst: new(float64), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fa7f800000"), // Infinity + dst: new(time.Time), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("faff800000"), // -Infinity + dst: new(interface{}), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("faff800000"), // -Infinity + dst: new(float32), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("faff800000"), // -Infinity + dst: new(float64), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("faff800000"), // -Infinity + dst: new(time.Time), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fb7ff0000000000000"), // Infinity + dst: new(interface{}), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fb7ff0000000000000"), // Infinity + dst: new(float32), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fb7ff0000000000000"), // Infinity + dst: new(float64), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fb7ff0000000000000"), // Infinity + dst: new(time.Time), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fbfff0000000000000"), // -Infinity + dst: new(interface{}), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fbfff0000000000000"), // -Infinity + dst: new(float32), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fbfff0000000000000"), // -Infinity + dst: new(float64), + reject: true, + }, + { + opt: InfDecodeForbidden, + src: hexDecode("fbfff0000000000000"), // -Infinity + dst: new(time.Time), + reject: true, + }, + } { + t.Run(fmt.Sprintf("mode=%d/0x%x into %s", tc.opt, tc.src, tc.dst), func(t *testing.T) { + dm, err := DecOptions{Inf: tc.opt}.DecMode() + if err != nil { + t.Fatal(err) + } + want := &UnacceptableDataItemError{ + CBORType: cborTypePrimitives.String(), + Message: "floating-point infinity", + } + if got := dm.Unmarshal(tc.src, tc.dst); got != nil { + if tc.reject { + if !reflect.DeepEqual(want, got) { + t.Errorf("want error: %v, got error: %v", want, got) + } + } else { + t.Errorf("unexpected error: %v", got) + } + } else if tc.reject { + t.Error("unexpected nil error") + } + }) + } +} diff --git a/encode.go b/encode.go index 79f32986..7a7eb84a 100644 --- a/encode.go +++ b/encode.go @@ -240,6 +240,9 @@ const ( // NaN payload. NaNConvertQuiet + // NaNConvertReject returns UnsupportedValueError on attempts to encode a NaN value. + NaNConvertReject + maxNaNConvert ) @@ -258,6 +261,9 @@ const ( // InfConvertNone never converts (used by CTAP2 Canonical CBOR). InfConvertNone + // InfConvertReject returns UnsupportedValueError on attempts to encode an infinite value. + InfConvertReject + maxInfConvert ) @@ -908,7 +914,10 @@ func encodeFloat(e *encoderBuffer, em *encMode, v reflect.Value) error { func encodeInf(e *encoderBuffer, em *encMode, v reflect.Value) error { f64 := v.Float() - if em.infConvert == InfConvertFloat16 { + switch em.infConvert { + case InfConvertReject: + return &UnsupportedValueError{msg: "floating-point infinity"} + case InfConvertFloat16: if f64 > 0 { e.Write(cborPositiveInfinity) } else { @@ -935,6 +944,9 @@ func encodeNaN(e *encoderBuffer, em *encMode, v reflect.Value) error { f32 := float32NaNFromReflectValue(v) return encodeFloat32(e, f32) + case NaNConvertReject: + return &UnsupportedValueError{msg: "floating-point NaN"} + default: // NaNConvertPreserveSignal, NaNConvertQuiet if v.Kind() == reflect.Float64 { f64 := v.Float() diff --git a/encode_test.go b/encode_test.go index 2ed416bd..5eb4f6e4 100644 --- a/encode_test.go +++ b/encode_test.go @@ -3004,7 +3004,7 @@ func TestInfConvert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { em, err := tc.opts.EncMode() if err != nil { - t.Errorf("EncMode() returned an error %v", err) + t.Fatalf("EncMode() returned an error %v", err) } b, err := em.Marshal(tc.v) if err != nil { @@ -3013,6 +3013,23 @@ func TestInfConvert(t *testing.T) { t.Errorf("Marshal(%v) = 0x%x, want 0x%x", tc.v, b, tc.wantCborData) } }) + var vName string + switch v := tc.v.(type) { + case float32: + vName = fmt.Sprintf("0x%x", math.Float32bits(v)) + case float64: + vName = fmt.Sprintf("0x%x", math.Float64bits(v)) + } + t.Run("reject inf "+vName, func(t *testing.T) { + em, err := EncOptions{InfConvert: InfConvertReject}.EncMode() + if err != nil { + t.Fatalf("EncMode() returned an error %v", err) + } + want := &UnsupportedValueError{msg: "floating-point infinity"} + if _, got := em.Marshal(tc.v); !reflect.DeepEqual(want, got) { + t.Errorf("expected Marshal(%v) to return error: %v, got: %v", tc.v, want, got) + } + }) } } @@ -3318,6 +3335,13 @@ func TestNaNConvert(t *testing.T) { }}, } for _, tc := range testCases { + var vName string + switch v := tc.v.(type) { + case float32: + vName = fmt.Sprintf("0x%x", math.Float32bits(v)) + case float64: + vName = fmt.Sprintf("0x%x", math.Float64bits(v)) + } for _, convert := range tc.convert { var convertName string switch convert.opt.NaNConvert { @@ -3330,18 +3354,11 @@ func TestNaNConvert(t *testing.T) { case NaNConvertQuiet: convertName = "ConvertQuiet" } - var vName string - switch v := tc.v.(type) { - case float32: - vName = fmt.Sprintf("0x%x", math.Float32bits(v)) - case float64: - vName = fmt.Sprintf("0x%x", math.Float64bits(v)) - } name := convertName + "_" + vName t.Run(name, func(t *testing.T) { em, err := convert.opt.EncMode() if err != nil { - t.Errorf("EncMode() returned an error %v", err) + t.Fatalf("EncMode() returned an error %v", err) } b, err := em.Marshal(tc.v) if err != nil { @@ -3351,6 +3368,17 @@ func TestNaNConvert(t *testing.T) { } }) } + + t.Run("ConvertReject_"+vName, func(t *testing.T) { + em, err := EncOptions{NaNConvert: NaNConvertReject}.EncMode() + if err != nil { + t.Fatalf("EncMode() returned an error %v", err) + } + want := &UnsupportedValueError{msg: "floating-point NaN"} + if _, got := em.Marshal(tc.v); !reflect.DeepEqual(want, got) { + t.Errorf("expected Marshal(%v) to return error: %v, got: %v", tc.v, want, got) + } + }) } } diff --git a/valid.go b/valid.go index 11013faa..06c07d5f 100644 --- a/valid.go +++ b/valid.go @@ -7,7 +7,10 @@ import ( "encoding/binary" "errors" "io" + "math" "strconv" + + "github.com/x448/float16" ) // SyntaxError is a description of a CBOR syntax error. @@ -295,24 +298,39 @@ func (d *decoder) wellformedHead() (t cborType, ai byte, val uint64, err error) if dataLen < 3 { return 0, 0, 0, io.ErrUnexpectedEOF } - val = uint64(binary.BigEndian.Uint16(d.data[d.off : d.off+2])) - d.off += 2 + if t == cborTypePrimitives { + val = uint64(binary.BigEndian.Uint16(d.data[d.off : d.off+2])) + d.off += 2 + if err := d.acceptableFloat(float64(float16.Frombits(uint16(val)).Float32())); err != nil { + return 0, 0, 0, err + } + } return t, ai, val, nil } if ai == 26 { if dataLen < 5 { return 0, 0, 0, io.ErrUnexpectedEOF } - val = uint64(binary.BigEndian.Uint32(d.data[d.off : d.off+4])) - d.off += 4 + if t == cborTypePrimitives { + val = uint64(binary.BigEndian.Uint32(d.data[d.off : d.off+4])) + d.off += 4 + if err := d.acceptableFloat(float64(math.Float32frombits(uint32(val)))); err != nil { + return 0, 0, 0, err + } + } return t, ai, val, nil } if ai == 27 { if dataLen < 9 { return 0, 0, 0, io.ErrUnexpectedEOF } - val = binary.BigEndian.Uint64(d.data[d.off : d.off+8]) - d.off += 8 + if t == cborTypePrimitives { + val = binary.BigEndian.Uint64(d.data[d.off : d.off+8]) + d.off += 8 + if err := d.acceptableFloat(math.Float64frombits(val)); err != nil { + return 0, 0, 0, err + } + } return t, ai, val, nil } if ai == 31 { @@ -327,3 +345,19 @@ func (d *decoder) wellformedHead() (t cborType, ai byte, val uint64, err error) // ai == 28, 29, 30 return 0, 0, 0, &SyntaxError{"cbor: invalid additional information " + strconv.Itoa(int(ai)) + " for type " + t.String()} } + +func (d *decoder) acceptableFloat(f float64) error { + switch { + case d.dm.nanDec == NaNDecodeForbidden && math.IsNaN(f): + return &UnacceptableDataItemError{ + CBORType: cborTypePrimitives.String(), + Message: "floating-point NaN", + } + case d.dm.infDec == InfDecodeForbidden && math.IsInf(f, 0): + return &UnacceptableDataItemError{ + CBORType: cborTypePrimitives.String(), + Message: "floating-point infinity", + } + } + return nil +}