Skip to content

Commit

Permalink
More optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
pboyd04 committed Sep 20, 2024
1 parent eb50787 commit d59475d
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 36 deletions.
2 changes: 1 addition & 1 deletion dbus.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ func alignment(t reflect.Type) int {
return 1
case reflect.Uint16, reflect.Int16:
return 2
case reflect.Uint, reflect.Int, reflect.Uint32, reflect.Int32, reflect.String, reflect.Array, reflect.Slice, reflect.Map:
case reflect.Uint, reflect.Int, reflect.Uint32, reflect.Int32, reflect.String, reflect.Array, reflect.Slice, reflect.Map, reflect.Bool:
return 4
case reflect.Uint64, reflect.Int64, reflect.Float64, reflect.Struct:
return 8
Expand Down
67 changes: 66 additions & 1 deletion encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,18 @@ func (enc *encoder) Encode(vs ...interface{}) (err error) {
return nil
}

func CountFDs(vs ...interface{}) (int, error) {
var err error
defer func() {
err, _ = recover().(error)
}()
count := 0
for _, v := range vs {
count += fdCounter(reflect.ValueOf(v), 0)
}
return count, err
}

// encode encodes the given value to the writer and panics on error. depth holds
// the depth of the container nesting.
func (enc *encoder) encode(v reflect.Value, depth int) {
Expand All @@ -273,7 +285,7 @@ func (enc *encoder) encode(v reflect.Value, depth int) {
if v.Bool() {
enc.binWriteIntType(uint32(1))
} else {
enc.binWriteIntType(uint32(1))
enc.binWriteIntType(uint32(0))
}
enc.pos += 4
case reflect.Int16:
Expand Down Expand Up @@ -414,3 +426,56 @@ func (enc *encoder) encode(v reflect.Value, depth int) {
panic(InvalidTypeError{v.Type()})
}
}

func fdCounter(v reflect.Value, depth int) int {
if depth > 64 {
panic(FormatError("input exceeds depth limitation"))
}
switch v.Kind() {
case reflect.Int, reflect.Int32:
if v.Type() == unixFDType {
return 1
}
return 0
case reflect.Ptr:
return fdCounter(v.Elem(), depth)
case reflect.Slice, reflect.Array:
// we don't really need the child encoder in this case since we aren't actually messing with the buffer at all
count := 0
for i := 0; i < v.Len(); i++ {
count += fdCounter(v.Index(i), depth+1)
}
return count
case reflect.Struct:
switch t := v.Type(); t {
case variantType:
variant := v.Interface().(Variant)
return fdCounter(reflect.ValueOf(variant.value), depth+1)
default:
count := 0
for i := 0; i < v.Type().NumField(); i++ {
field := t.Field(i)
if field.PkgPath == "" && field.Tag.Get("dbus") != "-" {
count += fdCounter(v.Field(i), depth+1)
}
}
return count
}
case reflect.Map:
// Maps are arrays of structures, so they actually increase the depth by
// 2.
// we don't really need the child encoder in this case since we aren't actually messing with the buffer at all
iter := v.MapRange()
count := 0
for iter.Next() {
count += fdCounter(iter.Key(), depth+2)
count += fdCounter(iter.Value(), depth+2)
}
return count
case reflect.Interface:
return fdCounter(reflect.ValueOf(MakeVariant(v.Interface())), depth)
default:
// do nothing we are skipping most types
return 0
}
}
50 changes: 19 additions & 31 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package dbus
import (
"bytes"
"encoding/binary"
"errors"
"io"
"reflect"
"strconv"
Expand Down Expand Up @@ -203,33 +202,20 @@ func DecodeMessage(rd io.Reader) (msg *Message, err error) {
return DecodeMessageWithFDs(rd, make([]int, 0))
}

type nullwriter struct{}

func (nullwriter) Write(p []byte) (cnt int, err error) {
return len(p), nil
}

func (msg *Message) CountFds() (int, error) {
if len(msg.Body) == 0 {
return 0, nil
}
enc := newEncoder(nullwriter{}, nativeEndian, make([]int, 0))
err := enc.Encode(msg.Body...)
return len(enc.fds), err
return CountFDs(msg.Body...)
}

func (msg *Message) EncodeToWithFDs(out io.Writer, order binary.ByteOrder) (fds []int, err error) {
if err := msg.validateHeader(); err != nil {
return nil, err
}
var vs [7]interface{}
switch order {
case binary.LittleEndian:
vs[0] = byte('l')
case binary.BigEndian:
vs[0] = byte('B')
default:
return nil, errors.New("dbus: invalid byte order")
endianByte := byte('l')
if order == binary.BigEndian {
endianByte = byte('B')
}
body := new(bytes.Buffer)
fds = make([]int, 0)
Expand All @@ -240,32 +226,34 @@ func (msg *Message) EncodeToWithFDs(out io.Writer, order binary.ByteOrder) (fds
return
}
}
vs[1] = msg.Type
vs[2] = msg.Flags
vs[3] = protoVersion
vs[4] = uint32(len(body.Bytes()))
vs[5] = msg.serial
headers := make([]header, 0, len(msg.Headers))
for k, v := range msg.Headers {
headers = append(headers, header{byte(k), v})
}
vs[6] = headers
var buf bytes.Buffer
enc = newEncoder(&buf, order, enc.fds)
err = enc.Encode(vs[:]...)
buf := bytes.NewBuffer(make([]byte, 0, 128))
// No need to alloc a new encoder, just reset the old one
enc.Reset(buf, order, enc.fds)
buf.WriteByte(endianByte)
buf.WriteByte(byte(msg.Type))
buf.WriteByte(byte(msg.Flags))
buf.WriteByte(protoVersion)
enc.binWriteIntType(uint32(len(body.Bytes())))
enc.binWriteIntType(msg.serial)
enc.pos = 12
err = enc.Encode(headers)
if err != nil {
return
}
enc.align(8)
if _, err := body.WriteTo(&buf); err != nil {
return nil, err
}
if buf.Len() > 1<<27 {
if buf.Len()+body.Len() > 1<<27 {
return nil, InvalidMessageError("message is too long")
}
if _, err := buf.WriteTo(out); err != nil {
return nil, err
}
if _, err := body.WriteTo(out); err != nil {
return nil, err
}
return enc.fds, nil
}

Expand Down
13 changes: 10 additions & 3 deletions transport_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"io"
"net"
"sync"
"syscall"
)

Expand All @@ -31,7 +32,6 @@ type oobReader struct {
// The following fields are used to reduce memory allocs.
csheader []byte
b *bytes.Buffer
dec *decoder
msghead
}

Expand Down Expand Up @@ -92,6 +92,12 @@ func (t *unixTransport) EnableUnixFDs() {
t.hasUnixFDs = true
}

var decodePool = sync.Pool{
New: func() interface{} {
return new(decoder)
},
}

func (t *unixTransport) ReadMessage() (*Message, error) {
// To be sure that all bytes of out-of-band data are read, we use a special
// reader that uses ReadUnix on the underlying connection instead of Read
Expand All @@ -102,15 +108,16 @@ func (t *unixTransport) ReadMessage() (*Message, error) {
// This buffer is used to decode the part of the header that has a constant size.
csheader: make([]byte, 16),
b: bytes.NewBuffer(make([]byte, defaultBufferSize)),
dec: &decoder{},
}
} else {
t.rdr.oob = t.rdr.oob[:0]
}
var (
b = t.rdr.b
dec = t.rdr.dec
dec = decodePool.Get().(*decoder)
)
// Put the decoder back in the pool for others to use.
defer decodePool.Put(dec)

b.Reset()
if _, err := io.CopyN(b, t.rdr, 16); err != nil {
Expand Down

0 comments on commit d59475d

Please sign in to comment.