Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compactfloats directive for lossless float64 -> float32 conversion #366

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions _generated/compactfloats.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package _generated

//go:generate msgp

//msgp:compactfloats

//msgp:ignore F64
type F64 float64

//msgp:replace F64 with:float64

type Floats struct {
A float64
B float32
Slice []float64
Map map[string]float64
F F64
OE float64 `msg:",omitempty"`
}
78 changes: 78 additions & 0 deletions _generated/compactfloats_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package _generated

import (
"bytes"
"reflect"
"testing"

"github.com/tinylib/msgp/msgp"
)

func TestCompactFloats(t *testing.T) {
// Constant that can be represented in f32 without loss
const f32ok = -1e2
allF32 := Floats{
A: f32ok,
B: f32ok,
Slice: []float64{f32ok, f32ok},
Map: map[string]float64{"a": f32ok},
F: f32ok,
OE: f32ok,
}
asF32 := float32(f32ok)
wantF32 := map[string]any{"A": asF32, "B": asF32, "F": asF32, "Map": map[string]any{"a": asF32}, "OE": asF32, "Slice": []any{asF32, asF32}}

enc, err := allF32.MarshalMsg(nil)
if err != nil {
t.Error(err)
}
i, _, _ := msgp.ReadIntfBytes(enc)
got := i.(map[string]any)
if !reflect.DeepEqual(got, wantF32) {
t.Errorf("want: %v, got: %v (diff may be types)", wantF32, got)
}

var buf bytes.Buffer
en := msgp.NewWriter(&buf)
allF32.EncodeMsg(en)
en.Flush()
enc = buf.Bytes()
i, _, _ = msgp.ReadIntfBytes(enc)
got = i.(map[string]any)
if !reflect.DeepEqual(got, wantF32) {
t.Errorf("want: %v, got: %v (diff may be types)", wantF32, got)
}

const f64ok = -10e64
allF64 := Floats{
A: f64ok,
B: f32ok,
Slice: []float64{f64ok, f64ok},
Map: map[string]float64{"a": f64ok},
F: f64ok,
OE: f64ok,
}
asF64 := float64(f64ok)
wantF64 := map[string]any{"A": asF64, "B": asF32, "F": asF64, "Map": map[string]any{"a": asF64}, "OE": asF64, "Slice": []any{asF64, asF64}}

enc, err = allF64.MarshalMsg(nil)
if err != nil {
t.Error(err)
}
i, _, _ = msgp.ReadIntfBytes(enc)
got = i.(map[string]any)
if !reflect.DeepEqual(got, wantF64) {
t.Errorf("want: %v, got: %v (diff may be types)", wantF64, got)
}

buf.Reset()
en = msgp.NewWriter(&buf)
allF64.EncodeMsg(en)
en.Flush()
enc = buf.Bytes()
i, _, _ = msgp.ReadIntfBytes(enc)
got = i.(map[string]any)
if !reflect.DeepEqual(got, wantF64) {
t.Errorf("want: %v, got: %v (diff may be types)", wantF64, got)
}
}
5 changes: 2 additions & 3 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ func (d *decodeGen) needsField() {
d.hasfield = true
}

func (d *decodeGen) Execute(p Elem) error {
func (d *decodeGen) Execute(p Elem, ctx Context) error {
d.ctx = &ctx
p = d.applyall(p)
if p == nil {
return nil
Expand All @@ -43,8 +44,6 @@ func (d *decodeGen) Execute(p Elem) error {
return nil
}

d.ctx = &Context{}

d.p.comment("DecodeMsg implements msgp.Decodable")

d.p.printf("\nfunc (%s %s) DecodeMsg(dc *msgp.Reader) (err error) {", p.Varname(), methodReceiver(p))
Expand Down
9 changes: 6 additions & 3 deletions gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func (e *encodeGen) Apply(dirs []string) error {
}

func (e *encodeGen) writeAndCheck(typ string, argfmt string, arg interface{}) {
if e.ctx.compFloats && typ == "Float64" {
typ = "Float"
}

e.p.printf("\nerr = en.Write%s(%s)", typ, fmt.Sprintf(argfmt, arg))
e.p.wrapErrCheck(e.ctx.ArgsStr())
}
Expand All @@ -47,7 +51,8 @@ func (e *encodeGen) Fuse(b []byte) {
}
}

func (e *encodeGen) Execute(p Elem) error {
func (e *encodeGen) Execute(p Elem, ctx Context) error {
e.ctx = &ctx
if !e.p.ok() {
return e.p.err
}
Expand All @@ -59,8 +64,6 @@ func (e *encodeGen) Execute(p Elem) error {
return nil
}

e.ctx = &Context{}

e.p.comment("EncodeMsg implements msgp.Encodable")
rcv := imutMethodReceiver(p)
ogVar := p.Varname()
Expand Down
8 changes: 5 additions & 3 deletions gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ func (m *marshalGen) Apply(dirs []string) error {
return nil
}

func (m *marshalGen) Execute(p Elem) error {
func (m *marshalGen) Execute(p Elem, ctx Context) error {
m.ctx = &ctx
if !m.p.ok() {
return m.p.err
}
Expand All @@ -39,8 +40,6 @@ func (m *marshalGen) Execute(p Elem) error {
return nil
}

m.ctx = &Context{}

m.p.comment("MarshalMsg implements msgp.Marshaler")

// save the vname before
Expand All @@ -64,6 +63,9 @@ func (m *marshalGen) Execute(p Elem) error {
}

func (m *marshalGen) rawAppend(typ string, argfmt string, arg interface{}) {
if m.ctx.compFloats && typ == "Float64" {
typ = "Float"
}
m.p.printf("\no = msgp.Append%s(o, %s)", typ, fmt.Sprintf(argfmt, arg))
}

Expand Down
4 changes: 2 additions & 2 deletions gen/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ func (s *sizeGen) addConstant(sz string) {
panic("unknown size state")
}

func (s *sizeGen) Execute(p Elem) error {
func (s *sizeGen) Execute(p Elem, ctx Context) error {
s.ctx = &ctx
if !s.p.ok() {
return s.p.err
}
Expand All @@ -81,7 +82,6 @@ func (s *sizeGen) Execute(p Elem) error {
return nil
}

s.ctx = &Context{}
s.ctx.PushString(p.TypeName())

s.p.comment("Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message")
Expand Down
10 changes: 6 additions & 4 deletions gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ const (
)

type Printer struct {
gens []generator
gens []generator
CompactFloats bool
}

func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
Expand Down Expand Up @@ -144,7 +145,7 @@ func (p *Printer) Print(e Elem) error {
// collisions between idents created during SetVarname and idents created during Print,
// hence the separate prefixes.
resetIdent("zb")
err := g.Execute(e)
err := g.Execute(e, Context{compFloats: p.CompactFloats})
resetIdent("za")

if err != nil {
Expand All @@ -171,7 +172,8 @@ func (c contextVar) Arg() string {
}

type Context struct {
path []contextItem
path []contextItem
compFloats bool
}

func (c *Context) PushString(s string) {
Expand Down Expand Up @@ -202,7 +204,7 @@ func (c *Context) ArgsStr() string {
type generator interface {
Method() Method
Add(p TransformPass)
Execute(Elem) error // execute writes the method for the provided object.
Execute(Elem, Context) error // execute writes the method for the provided object.
}

type passes []TransformPass
Expand Down
4 changes: 2 additions & 2 deletions gen/testgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type mtestGen struct {
w io.Writer
}

func (m *mtestGen) Execute(p Elem) error {
func (m *mtestGen) Execute(p Elem, _ Context) error {
p = m.applyall(p)
if p != nil && IsPrintable(p) {
switch p.(type) {
Expand All @@ -48,7 +48,7 @@ func etest(w io.Writer) *etestGen {
return &etestGen{w: w}
}

func (e *etestGen) Execute(p Elem) error {
func (e *etestGen) Execute(p Elem, _ Context) error {
p = e.applyall(p)
if p != nil && IsPrintable(p) {
switch p.(type) {
Expand Down
5 changes: 2 additions & 3 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ func (u *unmarshalGen) needsField() {
u.hasfield = true
}

func (u *unmarshalGen) Execute(p Elem) error {
func (u *unmarshalGen) Execute(p Elem, ctx Context) error {
u.hasfield = false
u.ctx = &ctx
if !u.p.ok() {
return u.p.err
}
Expand All @@ -41,8 +42,6 @@ func (u *unmarshalGen) Execute(p Elem) error {
return nil
}

u.ctx = &Context{}

u.p.comment("UnmarshalMsg implements msgp.Unmarshaler")

u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", p.Varname(), methodReceiver(p))
Expand Down
10 changes: 10 additions & 0 deletions msgp/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,16 @@ func (mw *Writer) WriteNil() error {
return mw.push(mnil)
}

// WriteFloat writes a float to the writer as either float64
// or float32 when it represents the exact same value
func (mw *Writer) WriteFloat(f float64) error {
f32 := float32(f)
if float64(f32) == f {
return mw.prefix32(mfloat32, math.Float32bits(f32))
}
return mw.prefix64(mfloat64, math.Float64bits(f))
}

// WriteFloat64 writes a float64 to the writer
func (mw *Writer) WriteFloat64(f float64) error {
return mw.prefix64(mfloat64, math.Float64bits(f))
Expand Down
10 changes: 10 additions & 0 deletions msgp/write_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ func AppendArrayHeader(b []byte, sz uint32) []byte {
// AppendNil appends a 'nil' byte to the slice
func AppendNil(b []byte) []byte { return append(b, mnil) }

// AppendFloat appends a float to the slice as either float64
// or float32 when it represents the exact same value
func AppendFloat(b []byte, f float64) []byte {
f32 := float32(f)
if float64(f32) == f {
return AppendFloat32(b, f32)
}
return AppendFloat64(b, f)
}

// AppendFloat64 appends a float64 to the slice
func AppendFloat64(b []byte, f float64) []byte {
o, n := ensure(b, Float64Size)
Expand Down
69 changes: 69 additions & 0 deletions msgp/write_bytes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package msgp
import (
"bytes"
"math"
"math/rand"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -134,6 +135,74 @@ func TestAppendNil(t *testing.T) {
}
}

func TestAppendFloat(t *testing.T) {
rng := rand.New(rand.NewSource(0))
const n = 1e7
src := make([]float64, n)
for i := range src {
// ~50% full float64, 50% converted from float32.
if rng.Uint32()&1 == 1 {
src[i] = rng.NormFloat64()
} else {
src[i] = float64(math.MaxFloat32 * (0.5 - rng.Float32()))
}
}

var buf bytes.Buffer
en := NewWriter(&buf)

var bts []byte
for _, f := range src {
en.WriteFloat(f)
bts = AppendFloat(bts, f)
}
en.Flush()
if buf.Len() != len(bts) {
t.Errorf("encoder wrote %d; append wrote %d bytes", buf.Len(), len(bts))
}
t.Logf("%f bytes/value", float64(buf.Len())/n)
a, b := bts, buf.Bytes()
for i := range a {
if a[i] != b[i] {
t.Errorf("mismatch at byte %d, %d != %d", i, a[i], b[i])
break
}
}

for i, want := range src {
var got float64
var err error
got, a, err = ReadFloat64Bytes(a)
if err != nil {
t.Fatal(err)
}
if want != got {
t.Errorf("value #%d: want %v; got %v", i, want, got)
}
}
}

func BenchmarkAppendFloat(b *testing.B) {
rng := rand.New(rand.NewSource(0))
const n = 1 << 16
src := make([]float64, n)
for i := range src {
// ~50% full float64, 50% converted from float32.
if rng.Uint32()&1 == 1 {
src[i] = rng.NormFloat64()
} else {
src[i] = float64(math.MaxFloat32 * (0.5 - rng.Float32()))
}
}
buf := make([]byte, 0, 9)
b.SetBytes(8)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
AppendFloat64(buf, src[i&(n-1)])
}
}

func TestAppendFloat64(t *testing.T) {
f := float64(3.14159)
var buf bytes.Buffer
Expand Down
Loading
Loading