Skip to content

Commit

Permalink
thriftbp: Report request/response payload sizes from client side
Browse files Browse the repository at this point in the history
We currently already have a middleware to report that from the server
side, but that's simulated by reconstruct THeaderProtocols and
read/write to them, which will not include the header size and will not
reflect compressions (zlib, etc.).

This client side reporting is done directly at the TTransport (TSocket)
level, so it will reflect the exact number actually read/written at that
level. If zlib is enabled in THeaderProtocol the size will reflect
compressed payload. The only thing it will not reflect is potential SSL
layer (e.g. when using TSSLSocket over TSocket, which we don't support
yet).
  • Loading branch information
fishy committed May 7, 2024
1 parent 904e7db commit 3868423
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 37 deletions.
6 changes: 4 additions & 2 deletions thriftbp/client_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -555,13 +555,15 @@ func newClient(
genAddr AddressGenerator,
protoFactory thrift.TProtocolFactory,
) (*ttlClient, error) {
return newTTLClient(func() (thrift.TClient, thrift.TTransport, error) {
return newTTLClient(func() (thrift.TClient, *countingDelegateTransport, error) {
addr, err := genAddr()
if err != nil {
return nil, nil, fmt.Errorf("thriftbp: error getting next address for new Thrift client: %w", err)
}

transport := thrift.NewTSocketConf(addr, cfg)
transport := &countingDelegateTransport{
TTransport: thrift.NewTSocketConf(addr, cfg),
}
if err := transport.Open(); err != nil {
return nil, nil, fmt.Errorf("thriftbp: error opening TSocket for new Thrift client: %w", err)
}
Expand Down
46 changes: 30 additions & 16 deletions thriftbp/prometheus.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,32 +135,46 @@ const (
)

var (
payloadSizeLabels = []string{
serverPayloadSizeLabels = []string{
methodLabel,
protoLabel,
}

clientPayloadSizeLabels = []string{
methodLabel,
clientNameLabel,
successLabel,
}

// 8 bytes to 4 mebibytes
// some endpoints can have very small, almost zero payloads (for example,
// is_healthy), but we do have some endpoints with very large payloads
// (up to ~500 KiB).
payloadSizeBuckets = prometheus.ExponentialBuckets(8, 2, 20)

payloadSizeRequestBytes = promauto.With(prometheusbpint.GlobalRegistry).NewHistogramVec(prometheus.HistogramOpts{
Namespace: promNamespace,
Subsystem: subsystemServer,
Name: "request_payload_size_bytes",
Help: "The size of thrift request payloads",
Buckets: payloadSizeBuckets,
}, payloadSizeLabels)

payloadSizeResponseBytes = promauto.With(prometheusbpint.GlobalRegistry).NewHistogramVec(prometheus.HistogramOpts{
Namespace: promNamespace,
Subsystem: subsystemServer,
Name: "response_payload_size_bytes",
Help: "The size of thrift response payloads",
Buckets: payloadSizeBuckets,
}, payloadSizeLabels)
serverPayloadSizeRequestBytes = promauto.With(prometheusbpint.GlobalRegistry).NewHistogramVec(prometheus.HistogramOpts{
Name: "thriftbp_server_request_payload_size_bytes",
Help: "The (approximate) size of thrift request payloads",
Buckets: payloadSizeBuckets,
}, serverPayloadSizeLabels)

serverPayloadSizeResponseBytes = promauto.With(prometheusbpint.GlobalRegistry).NewHistogramVec(prometheus.HistogramOpts{
Name: "thriftbp_server_response_payload_size_bytes",
Help: "The (approximate) size of thrift response payloads",
Buckets: payloadSizeBuckets,
}, serverPayloadSizeLabels)

clientPayloadSizeRequestBytes = promauto.With(prometheusbpint.GlobalRegistry).NewHistogramVec(prometheus.HistogramOpts{
Name: "thriftbp_client_request_payload_size_bytes",
Help: "The size of thrift request payloads",
Buckets: payloadSizeBuckets,
}, clientPayloadSizeLabels)

clientPayloadSizeResponseBytes = promauto.With(prometheusbpint.GlobalRegistry).NewHistogramVec(prometheus.HistogramOpts{
Name: "thriftbp_client_response_payload_size_bytes",
Help: "The size of thrift response payloads",
Buckets: payloadSizeBuckets,
}, clientPayloadSizeLabels)
)

var (
Expand Down
16 changes: 7 additions & 9 deletions thriftbp/server_middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type DefaultProcessorMiddlewaresArgs struct {

// Report the payload size metrics with this sample rate.
//
// This is optional. If it's not set none of the requests will be sampled.
// Deprecated: Prometheus payload size metrics are always 100% reported.
ReportPayloadSizeMetricsSampleRate float64

// The edge context implementation. Optional.
Expand Down Expand Up @@ -75,7 +75,7 @@ func BaseplateDefaultProcessorMiddlewares(args DefaultProcessorMiddlewaresArgs)
ExtractDeadlineBudget,
InjectServerSpan(args.ErrorSpanSuppressor),
InjectEdgeContext(args.EdgeContextImpl),
ReportPayloadSizeMetrics(args.ReportPayloadSizeMetricsSampleRate),
ReportPayloadSizeMetrics(0),
PrometheusServerMiddleware,
}
}
Expand Down Expand Up @@ -277,11 +277,9 @@ func AbandonCanceledRequests(name string, next thrift.TProcessorFunction) thrift
// If the request is not in THeaderProtocol it does nothing no matter what the
// sample rate is.
//
// For endpoint named "myEndpoint", it reports histograms at:
//
// - payload.size.myEndpoint.request
//
// - payload.size.myEndpoint.response
// The prometheus histograms are:
// - thriftbp_server_request_payload_size_bytes
// - thriftbp_server_response_payload_size_bytes
func ReportPayloadSizeMetrics(_ float64) thrift.ProcessorMiddleware {
return func(name string, next thrift.TProcessorFunction) thrift.TProcessorFunction {
return thrift.WrappedTProcessorFunction{
Expand Down Expand Up @@ -317,8 +315,8 @@ func ReportPayloadSizeMetrics(_ float64) thrift.ProcessorMiddleware {
methodLabel: name,
protoLabel: proto,
}
payloadSizeRequestBytes.With(labels).Observe(isize)
payloadSizeResponseBytes.With(labels).Observe(osize)
serverPayloadSizeRequestBytes.With(labels).Observe(isize)
serverPayloadSizeResponseBytes.With(labels).Observe(osize)
}()
}

Expand Down
43 changes: 40 additions & 3 deletions thriftbp/ttl_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package thriftbp

import (
"context"
"sync/atomic"
"time"

"github.com/apache/thrift/lib/go/thrift"
Expand All @@ -11,7 +12,7 @@ import (
"github.com/reddit/baseplate.go/randbp"
)

type ttlClientGenerator func() (thrift.TClient, thrift.TTransport, error)
type ttlClientGenerator func() (thrift.TClient, *countingDelegateTransport, error)

// DefaultMaxConnectionAge is the default max age for a Thrift client connection.
const DefaultMaxConnectionAge = time.Minute * 5
Expand All @@ -24,7 +25,7 @@ var _ Client = (*ttlClient)(nil)

type ttlClientState struct {
client thrift.TClient
transport thrift.TTransport
transport *countingDelegateTransport
expiration time.Time // if expiration is zero, then the client will be kept open indefinetly.
timer *time.Timer
closed bool
Expand Down Expand Up @@ -66,11 +67,22 @@ func (c *ttlClient) Close() error {
return state.transport.Close()
}

func (c *ttlClient) Call(ctx context.Context, method string, args, result thrift.TStruct) (thrift.ResponseMeta, error) {
func (c *ttlClient) Call(ctx context.Context, method string, args, result thrift.TStruct) (_ thrift.ResponseMeta, err error) {
state := <-c.state
defer func() {
c.state <- state
}()

defer func() {
read, written := state.transport.getBytesAndReset()
labels := prometheus.Labels{
methodLabel: method,
clientNameLabel: c.slug,
successLabel: prometheusbp.BoolString(err == nil),
}
clientPayloadSizeRequestBytes.With(labels).Observe(float64(written))
clientPayloadSizeResponseBytes.With(labels).Observe(float64(read))
}()
return state.client.Call(ctx, method, args, result)
}

Expand Down Expand Up @@ -166,3 +178,28 @@ func newTTLClient(generator ttlClientGenerator, ttl time.Duration, jitter float6

return c, nil
}

type countingDelegateTransport struct {
thrift.TTransport

bytesRead atomic.Uint64
bytesWritten atomic.Uint64
}

func (cdt *countingDelegateTransport) Read(p []byte) (n int, err error) {
defer func() {
cdt.bytesRead.Add(uint64(n))
}()
return cdt.TTransport.Read(p)
}

func (cdt *countingDelegateTransport) Write(p []byte) (n int, err error) {
defer func() {
cdt.bytesWritten.Add(uint64(n))
}()
return cdt.TTransport.Write(p)
}

func (cdt *countingDelegateTransport) getBytesAndReset() (read, written uint64) {
return cdt.bytesRead.Swap(0), cdt.bytesWritten.Swap(0)
}
60 changes: 53 additions & 7 deletions thriftbp/ttl_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import (

// firstSuccessGenerator is a ttlClientGenerator implementation that would
// return client and transport on the first call, and errors afterwards.
func firstSuccessGenerator(transport thrift.TTransport) ttlClientGenerator {
func firstSuccessGenerator(transport *countingDelegateTransport) ttlClientGenerator {
factory := thrift.NewTBinaryProtocolFactoryConf(nil)
client := thrift.NewTStandardClient(
factory.GetProtocol(transport),
factory.GetProtocol(transport),
)
first := true
return func() (thrift.TClient, thrift.TTransport, error) {
return func() (thrift.TClient, *countingDelegateTransport, error) {
if first {
first = false
return client, transport, nil
Expand All @@ -28,7 +28,9 @@ func firstSuccessGenerator(transport thrift.TTransport) ttlClientGenerator {
}

func TestTTLClient(t *testing.T) {
transport := thrift.NewTMemoryBuffer()
transport := &countingDelegateTransport{
TTransport: thrift.NewTMemoryBuffer(),
}
ttl := time.Millisecond
jitter := 0.1

Expand Down Expand Up @@ -60,7 +62,9 @@ func TestTTLClient(t *testing.T) {
}

func TestTTLClientNegativeTTL(t *testing.T) {
transport := thrift.NewTMemoryBuffer()
transport := &countingDelegateTransport{
TTransport: thrift.NewTMemoryBuffer(),
}
ttl := time.Millisecond

client, err := newTTLClient(firstSuccessGenerator(transport), -ttl, 0.1, "")
Expand Down Expand Up @@ -114,7 +118,7 @@ func TestTTLClientRenew(t *testing.T) {
// alwaysSuccessGenerator is a ttlClientGenerator implementation that would
// always return client, transport, and no error.
type alwaysSuccessGenerator struct {
transport thrift.TTransport
transport *countingDelegateTransport

called atomic.Int64
}
Expand All @@ -125,7 +129,7 @@ func (g *alwaysSuccessGenerator) generator() ttlClientGenerator {
factory.GetProtocol(g.transport),
factory.GetProtocol(g.transport),
)
return func() (thrift.TClient, thrift.TTransport, error) {
return func() (thrift.TClient, *countingDelegateTransport, error) {
g.called.Add(1)
return client, g.transport, nil
}
Expand Down Expand Up @@ -159,7 +163,9 @@ func TestTTLClientRefresh(t *testing.T) {
jitter = 0
)

g := alwaysSuccessGenerator{transport: &transport}
g := alwaysSuccessGenerator{transport: &countingDelegateTransport{
TTransport: &transport,
}}
client, err := newTTLClient(g.generator(), ttl, jitter, "")
if err != nil {
t.Fatalf("newTTLClient returned error: %v", err)
Expand Down Expand Up @@ -191,3 +197,43 @@ func TestTTLClientRefresh(t *testing.T) {
}
})
}

func TestCountingDelegateTransport(t *testing.T) {
const payload = "payload"

membuf := thrift.NewTMemoryBuffer()
transport := countingDelegateTransport{
TTransport: membuf,
}

if _, err := transport.Write([]byte(payload)); err != nil {
t.Fatalf("Failed to write: %v", err)
}

var buf [1024]byte
n, err := transport.Read(buf[:])
if err != nil {
t.Fatalf("Failed to read: %v", err)
}
if got := string(buf[:n]); got != payload {
t.Errorf("Read %q want %q", got, payload)
}

read, written := transport.getBytesAndReset()
want := uint64(len(payload))
if read != want {
t.Errorf("Read %d bytes want %d", read, want)
}
if written != want {
t.Errorf("Written %d bytes want %d", written, want)
}

read, written = transport.getBytesAndReset()
want = 0
if read != want {
t.Errorf("After reset: Read %d bytes want %d", read, want)
}
if written != want {
t.Errorf("After reset: Written %d bytes want %d", written, want)
}
}

0 comments on commit 3868423

Please sign in to comment.