Skip to content

Commit

Permalink
Allow rejection of NaN and Inf float values on encode and decode.
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Luddy <[email protected]>
  • Loading branch information
benluddy committed Apr 2, 2024
1 parent 3cec62b commit 0312aa9
Show file tree
Hide file tree
Showing 4 changed files with 648 additions and 10 deletions.
114 changes: 114 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,18 @@ func (e *UnknownFieldError) Error() string {
return fmt.Sprintf("cbor: found unknown field at map element index %d", e.Index)
}

// UnacceptableDataItemError is returned when unmarshaling a CBOR input that contains a data item
// that is not acceptable to a specific CBOR-based application protocol ("invalid or unexpected" as
// described in RFC 8949 Section 5 Paragraph 3).
type UnacceptableDataItemError struct {
CBORType string
Message string
}

func (e UnacceptableDataItemError) Error() string {
return fmt.Sprintf("cbor: data item of cbor type %s is not accepted by protocol: %s", e.CBORType, e.Message)
}

// DupMapKeyMode specifies how to enforce duplicate map key. Two map keys are considered duplicates if:
// 1. When decoding into a struct, both keys match the same struct field. The keys are also
// considered duplicates if neither matches any field and decoding to interface{} would produce
Expand Down Expand Up @@ -496,6 +508,43 @@ func (tttam TimeTagToAnyMode) valid() bool {
return tttam >= 0 && tttam < maxTimeTagToAnyMode
}

// NaNDecodeMode specifies how to decode floating-point values (major type 7, additional information
// 25 through 27) representing NaN (not-a-number).
type NaNDecodeMode int

const (
// NaNDecodeAccept will decode NaN values to Go float32 or float64.
NaNDecodeAccept NaNDecodeMode = iota

// NaNDecodeReject will return an UnacceptableDataItemError on an attempt to decode a NaN value.
NaNDecodeReject

maxNaNDecode
)

func (ndm NaNDecodeMode) valid() bool {
return ndm >= 0 && ndm < maxNaNDecode
}

// InfDecodeMode specifies how to decode floating-point values (major type 7, additional information
// 25 through 27) representing positive or negative infinity.
type InfDecodeMode int

const (
// InfDecodeAccept will decode infinite values to Go float32 or float64.
InfDecodeAccept InfDecodeMode = iota

// InfDecodeReject will return an UnacceptableDataItemError on an attempt to decode an
// infinite value.
InfDecodeReject

maxInfDecode
)

func (idm InfDecodeMode) valid() bool {
return idm >= 0 && idm < maxInfDecode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -592,6 +641,14 @@ type DecOptions struct {
// TimeTagToAnyMode specifies how to decode CBOR tag 0 and 1 into an empty interface (any).
// Based on the specified mode, Unmarshal can return a time.Time value or a time string in a specific format.
TimeTagToAny TimeTagToAnyMode

// NaNDec specifies how to decode floating-point values (major type 7, additional
// information 25 through 27) representing NaN (not-a-number).
NaNDec NaNDecodeMode

// InfDec specifies how to decode floating-point values (major type 7, additional
// information 25 through 27) representing positive or negative infinity.
InfDec InfDecodeMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -749,6 +806,14 @@ func (opts DecOptions) decMode() (*decMode, error) {
return nil, errors.New("cbor: invalid TimeTagToAny " + strconv.Itoa(int(opts.TimeTagToAny)))
}

if !opts.NaNDec.valid() {
return nil, errors.New("cbor: invalid NaNDec " + strconv.Itoa(int(opts.NaNDec)))
}

if !opts.InfDec.valid() {
return nil, errors.New("cbor: invalid InfDec " + strconv.Itoa(int(opts.InfDec)))
}

dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand All @@ -769,6 +834,8 @@ func (opts DecOptions) decMode() (*decMode, error) {
fieldNameByteString: opts.FieldNameByteString,
unrecognizedTagToAny: opts.UnrecognizedTagToAny,
timeTagToAny: opts.TimeTagToAny,
nanDec: opts.NaNDec,
infDec: opts.InfDec,
}

return &dm, nil
Expand Down Expand Up @@ -841,6 +908,8 @@ type decMode struct {
fieldNameByteString FieldNameByteStringMode
unrecognizedTagToAny UnrecognizedTagToAnyMode
timeTagToAny TimeTagToAnyMode
nanDec NaNDecodeMode
infDec InfDecodeMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand All @@ -867,6 +936,8 @@ func (dm *decMode) DecOptions() DecOptions {
FieldNameByteString: dm.fieldNameByteString,
UnrecognizedTagToAny: dm.unrecognizedTagToAny,
TimeTagToAny: dm.timeTagToAny,
NaNDec: dm.nanDec,
InfDec: dm.infDec,
}
}

Expand Down Expand Up @@ -1181,12 +1252,21 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
switch ai {
case 25:
f := float64(float16.Frombits(uint16(val)).Float32())
if err := d.acceptableFloat(f); err != nil {
return err
}
return fillFloat(t, f, v)
case 26:
f := float64(math.Float32frombits(uint32(val)))
if err := d.acceptableFloat(f); err != nil {
return err
}
return fillFloat(t, f, v)
case 27:
f := math.Float64frombits(val)
if err := d.acceptableFloat(f); err != nil {
return err
}
return fillFloat(t, f, v)
default: // ai <= 24
switch ai {
Expand Down Expand Up @@ -1373,10 +1453,19 @@ func (d *decoder) parseToTime() (time.Time, bool, error) {
switch ai {
case 25:
f = float64(float16.Frombits(uint16(val)).Float32())
if err := d.acceptableFloat(f); err != nil {
return time.Time{}, false, err
}
case 26:
f = float64(math.Float32frombits(uint32(val)))
if err := d.acceptableFloat(f); err != nil {
return time.Time{}, false,err
}
case 27:
f = math.Float64frombits(val)
if err := d.acceptableFloat(f); err != nil {
return time.Time{}, false, err
}
default:
return time.Time{}, false, &UnmarshalTypeError{CBORType: t.String(), GoType: typeTime.String()}
}
Expand Down Expand Up @@ -1617,12 +1706,21 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return nil, nil
case 25:
f := float64(float16.Frombits(uint16(val)).Float32())
if err := d.acceptableFloat(f); err != nil {
return nil, err
}
return f, nil
case 26:
f := float64(math.Float32frombits(uint32(val)))
if err := d.acceptableFloat(f); err != nil {
return nil, err
}
return f, nil
case 27:
f := math.Float64frombits(val)
if err := d.acceptableFloat(f); err != nil {
return nil, err
}
return f, nil
}
case cborTypeArray:
Expand All @@ -1641,6 +1739,22 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return nil, nil
}

func (d *decoder) acceptableFloat(f float64) error {
switch {
case d.dm.nanDec == NaNDecodeReject && math.IsNaN(f):
return &UnacceptableDataItemError{
CBORType: cborTypePrimitives.String(),
Message: "floating-point NaN",
}
case d.dm.infDec == InfDecodeReject && math.IsInf(f, 0):
return &UnacceptableDataItemError{
CBORType: cborTypePrimitives.String(),
Message: "floating-point infinity",
}
}
return nil
}

// parseByteString parses a CBOR encoded byte string. The returned byte slice
// may be backed directly by the input. The second return value will be true if
// and only if the slice is backed by a copy of the input. Callers are
Expand Down
Loading

0 comments on commit 0312aa9

Please sign in to comment.