Skip to content

Commit

Permalink
Merge pull request #4592 from skmcgrail/infinity
Browse files Browse the repository at this point in the history
Protocol Support for NaN/Infinity/-Infinity float values
  • Loading branch information
RanVaknin authored Oct 31, 2022
2 parents 3896c7a + cf09be1 commit 9bfd737
Show file tree
Hide file tree
Showing 13 changed files with 491 additions and 65 deletions.
19 changes: 15 additions & 4 deletions private/protocol/json/jsonutil/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package jsonutil
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"math"
"reflect"
Expand All @@ -16,6 +15,12 @@ import (
"github.com/aws/aws-sdk-go/private/protocol"
)

const (
floatNaN = "NaN"
floatInf = "Infinity"
floatNegInf = "-Infinity"
)

var timeType = reflect.ValueOf(time.Time{}).Type()
var byteSliceType = reflect.ValueOf([]byte{}).Type()

Expand Down Expand Up @@ -211,10 +216,16 @@ func buildScalar(v reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) erro
buf.Write(strconv.AppendInt(scratch[:0], value.Int(), 10))
case reflect.Float64:
f := value.Float()
if math.IsInf(f, 0) || math.IsNaN(f) {
return &json.UnsupportedValueError{Value: v, Str: strconv.FormatFloat(f, 'f', -1, 64)}
switch {
case math.IsNaN(f):
writeString(floatNaN, buf)
case math.IsInf(f, 1):
writeString(floatInf, buf)
case math.IsInf(f, -1):
writeString(floatNegInf, buf)
default:
buf.Write(strconv.AppendFloat(scratch[:0], f, 'f', -1, 64))
}
buf.Write(strconv.AppendFloat(scratch[:0], f, 'f', -1, 64))
default:
switch converted := value.Interface().(type) {
case time.Time:
Expand Down
40 changes: 24 additions & 16 deletions private/protocol/json/jsonutil/build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jsonutil_test

import (
"encoding/json"
"math"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -41,41 +42,48 @@ var jsonTests = []struct {
err string
}{
{
J{},
`{}`,
``,
in: J{},
out: `{}`,
},
{
J{
in: J{
S: S("str"),
SS: []string{"A", "B", "C"},
D: D(123),
F: F(4.56),
T: T(time.Unix(987, 0)),
},
`{"S":"str","SS":["A","B","C"],"D":123,"F":4.56,"T":987}`,
``,
out: `{"S":"str","SS":["A","B","C"],"D":123,"F":4.56,"T":987}`,
},
{
J{
in: J{
S: S(`"''"`),
},
`{"S":"\"''\""}`,
``,
out: `{"S":"\"''\""}`,
},
{
J{
in: J{
S: S("\x00føø\u00FF\n\\\"\r\t\b\f"),
},
`{"S":"\u0000føøÿ\n\\\"\r\t\b\f"}`,
``,
out: `{"S":"\u0000føøÿ\n\\\"\r\t\b\f"}`,
},
{
J{
F: F(4.56 / zero),
in: J{
F: F(math.NaN()),
},
"",
`json: unsupported value: +Inf`,
out: `{"F":"NaN"}`,
},
{
in: J{
F: F(math.Inf(1)),
},
out: `{"F":"Infinity"}`,
},
{
in: J{
F: F(math.Inf(-1)),
},
out: `{"F":"-Infinity"}`,
},
}

Expand Down
13 changes: 13 additions & 0 deletions private/protocol/json/jsonutil/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"math/big"
"reflect"
"strings"
Expand Down Expand Up @@ -258,6 +259,18 @@ func (u unmarshaler) unmarshalScalar(value reflect.Value, data interface{}, tag
return err
}
value.Set(reflect.ValueOf(v))
case *float64:
// These are regular strings when parsed by encoding/json's unmarshaler.
switch {
case strings.EqualFold(d, floatNaN):
value.Set(reflect.ValueOf(aws.Float64(math.NaN())))
case strings.EqualFold(d, floatInf):
value.Set(reflect.ValueOf(aws.Float64(math.Inf(1))))
case strings.EqualFold(d, floatNegInf):
value.Set(reflect.ValueOf(aws.Float64(math.Inf(-1))))
default:
return fmt.Errorf("unknown JSON number value: %s", d)
}
default:
return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
}
Expand Down
35 changes: 32 additions & 3 deletions private/protocol/json/jsonutil/unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package jsonutil_test

import (
"bytes"
"math"
"reflect"
"testing"
"time"
Expand All @@ -21,9 +22,10 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) {
}

cases := map[string]struct {
JSON string
Value input
Expected input
JSON string
Value input
Expected input
ExpectedFn func(*testing.T, input)
}{
"seconds precision": {
JSON: `{"timeField":1597094942}`,
Expand Down Expand Up @@ -106,6 +108,29 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) {
FloatField: aws.Float64(123456789.123),
},
},
"float64 field NaN": {
JSON: `{"floatField":"NaN"}`,
ExpectedFn: func(t *testing.T, input input) {
if input.FloatField == nil {
t.Fatal("expect non nil float64")
}
if e, a := true, math.IsNaN(*input.FloatField); e != a {
t.Errorf("expect %v, got %v", e, a)
}
},
},
"float64 field Infinity": {
JSON: `{"floatField":"Infinity"}`,
Expected: input{
FloatField: aws.Float64(math.Inf(1)),
},
},
"float64 field -Infinity": {
JSON: `{"floatField":"-Infinity"}`,
Expected: input{
FloatField: aws.Float64(math.Inf(-1)),
},
},
}

for name, tt := range cases {
Expand All @@ -114,6 +139,10 @@ func TestUnmarshalJSON_JSONNumber(t *testing.T) {
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if tt.ExpectedFn != nil {
tt.ExpectedFn(t, tt.Value)
return
}
if e, a := tt.Expected, tt.Value; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
Expand Down
34 changes: 32 additions & 2 deletions private/protocol/query/queryutil/queryutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package queryutil
import (
"encoding/base64"
"fmt"
"math"
"net/url"
"reflect"
"sort"
Expand All @@ -13,6 +14,12 @@ import (
"github.com/aws/aws-sdk-go/private/protocol"
)

const (
floatNaN = "NaN"
floatInf = "Infinity"
floatNegInf = "-Infinity"
)

// Parse parses an object i and fills a url.Values object. The isEC2 flag
// indicates if this is the EC2 Query sub-protocol.
func Parse(body url.Values, i interface{}, isEC2 bool) error {
Expand Down Expand Up @@ -228,9 +235,32 @@ func (q *queryParser) parseScalar(v url.Values, r reflect.Value, name string, ta
case int:
v.Set(name, strconv.Itoa(value))
case float64:
v.Set(name, strconv.FormatFloat(value, 'f', -1, 64))
var str string
switch {
case math.IsNaN(value):
str = floatNaN
case math.IsInf(value, 1):
str = floatInf
case math.IsInf(value, -1):
str = floatNegInf
default:
str = strconv.FormatFloat(value, 'f', -1, 64)
}
v.Set(name, str)
case float32:
v.Set(name, strconv.FormatFloat(float64(value), 'f', -1, 32))
asFloat64 := float64(value)
var str string
switch {
case math.IsNaN(asFloat64):
str = floatNaN
case math.IsInf(asFloat64, 1):
str = floatInf
case math.IsInf(asFloat64, -1):
str = floatNegInf
default:
str = strconv.FormatFloat(asFloat64, 'f', -1, 32)
}
v.Set(name, str)
case time.Time:
const ISO8601UTC = "2006-01-02T15:04:05Z"
format := tag.Get("timestampFormat")
Expand Down
18 changes: 17 additions & 1 deletion private/protocol/rest/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/base64"
"fmt"
"io"
"math"
"net/http"
"net/url"
"path"
Expand All @@ -20,6 +21,12 @@ import (
"github.com/aws/aws-sdk-go/private/protocol"
)

const (
floatNaN = "NaN"
floatInf = "Infinity"
floatNegInf = "-Infinity"
)

// Whether the byte value can be sent without escaping in AWS URLs
var noEscape [256]bool

Expand Down Expand Up @@ -302,7 +309,16 @@ func convertType(v reflect.Value, tag reflect.StructTag) (str string, err error)
case int64:
str = strconv.FormatInt(value, 10)
case float64:
str = strconv.FormatFloat(value, 'f', -1, 64)
switch {
case math.IsNaN(value):
str = floatNaN
case math.IsInf(value, 1):
str = floatInf
case math.IsInf(value, -1):
str = floatNegInf
default:
str = strconv.FormatFloat(value, 'f', -1, 64)
}
case time.Time:
format := tag.Get("timestampFormat")
if len(format) == 0 {
Expand Down
Loading

0 comments on commit 9bfd737

Please sign in to comment.