Skip to content
This repository has been archived by the owner on Jul 31, 2023. It is now read-only.

Commit

Permalink
Allow customizing gRPC ServerHandler (#558)
Browse files Browse the repository at this point in the history
* Added ServerHandler.IsPublicEndpoint to allow a configuration
  suitable for a public-facing server.
* Added ServerHandler.StartOptions to allow customizing the sampler
  used for new spans.
  • Loading branch information
Ramon Nogueira authored Mar 14, 2018
1 parent 6891f95 commit 0ac2803
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 72 deletions.
5 changes: 4 additions & 1 deletion plugin/ocgrpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,23 @@ func (c *ClientHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
// no-op
}

// TagConn exists to satisfy gRPC stats.Handler.
func (c *ClientHandler) TagConn(ctx context.Context, cti *stats.ConnTagInfo) context.Context {
// no-op
return ctx
}

// HandleRPC implements per-RPC tracing and stats instrumentation.
func (c *ClientHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) {
if !c.NoTrace {
c.traceHandleRPC(ctx, rs)
traceHandleRPC(ctx, rs)
}
if !c.NoStats {
c.statsHandleRPC(ctx, rs)
}
}

// TagRPC implements per-RPC context management.
func (c *ClientHandler) TagRPC(ctx context.Context, rti *stats.RPCTagInfo) context.Context {
if !c.NoTrace {
ctx = c.traceTagRPC(ctx, rti)
Expand Down
8 changes: 4 additions & 4 deletions plugin/ocgrpc/client_stats_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
"google.golang.org/grpc/status"
)

// TagRPC gets the tag.Map populated by the application code, serializes
// statsTagRPC gets the tag.Map populated by the application code, serializes
// its tags into the GRPC metadata in order to be sent to the server.
func (h *ClientHandler) statsTagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
startTime := time.Now()
Expand All @@ -54,7 +54,7 @@ func (h *ClientHandler) statsTagRPC(ctx context.Context, info *stats.RPCTagInfo)
return context.WithValue(ctx, grpcClientRPCKey, d)
}

// HandleRPC processes the RPC events.
// statsHandleRPC processes the RPC events.
func (h *ClientHandler) statsHandleRPC(ctx context.Context, s stats.RPCStats) {
switch st := s.(type) {
case *stats.Begin, *stats.OutHeader, *stats.InHeader, *stats.InTrailer, *stats.OutTrailer:
Expand Down Expand Up @@ -87,7 +87,7 @@ func (h *ClientHandler) handleRPCInPayload(ctx context.Context, s *stats.InPaylo
d, ok := ctx.Value(grpcClientRPCKey).(*rpcData)
if !ok {
if grpclog.V(2) {
grpclog.Infoln("clientHandler.handleRPCInPayload failed to retrieve *rpcData from context")
grpclog.Infoln("failed to retrieve *rpcData from context")
}
return
}
Expand All @@ -100,7 +100,7 @@ func (h *ClientHandler) handleRPCEnd(ctx context.Context, s *stats.End) {
d, ok := ctx.Value(grpcClientRPCKey).(*rpcData)
if !ok {
if grpclog.V(2) {
grpclog.Infoln("clientHandler.handleRPCEnd failed to retrieve *rpcData from context")
grpclog.Infoln("failed to retrieve *rpcData from context")
}
return
}
Expand Down
2 changes: 2 additions & 0 deletions plugin/ocgrpc/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@

// Package ocgrpc contains OpenCensus stats and trace
// integrations for gRPC.
//
// Use ServerHandler for servers and ClientHandler for clients.
package ocgrpc // import "go.opencensus.io/plugin/ocgrpc"
92 changes: 56 additions & 36 deletions plugin/ocgrpc/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"go.opencensus.io/stats/view"
"golang.org/x/net/context"
"google.golang.org/grpc/metadata"

"go.opencensus.io/trace"

Expand Down Expand Up @@ -70,45 +71,64 @@ func TestClientHandler(t *testing.T) {
}

func TestServerHandler(t *testing.T) {
ctx := context.Background()
te := &traceExporter{}
trace.RegisterExporter(te)
if err := ServerRequestCountView.Subscribe(); err != nil {
t.Fatal(err)
tests := []struct {
name string
newTrace bool
expectTraces int
}{
{"trust_metadata", false, 1},
{"no_trust_metadata", true, 0},
}

// Ensure we start tracing.
span := trace.NewSpan("/foo", nil, trace.StartOptions{
Sampler: trace.AlwaysSample(),
})
ctx = trace.WithSpan(ctx, span)

handler := &ServerHandler{}
ctx = handler.TagRPC(ctx, &stats.RPCTagInfo{
FullMethodName: "/service.foo/method",
})
handler.HandleRPC(ctx, &stats.Begin{
BeginTime: time.Now(),
})
handler.HandleRPC(ctx, &stats.End{
EndTime: time.Now(),
})

stats, err := view.RetrieveData(ServerRequestCountView.Name)
if err != nil {
t.Fatal(err)
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {

ctx := context.Background()

handler := &ServerHandler{
IsPublicEndpoint: test.newTrace,
StartOptions: trace.StartOptions{
Sampler: trace.ProbabilitySampler(0.0),
},
}

te := &traceExporter{}
trace.RegisterExporter(te)
if err := ServerRequestCountView.Subscribe(); err != nil {
t.Fatal(err)
}

md := metadata.MD{
"grpc-trace-bin": []string{string([]byte{0, 0, 62, 116, 14, 118, 117, 157, 126, 7, 114, 152, 102, 125, 235, 34, 114, 238, 1, 187, 201, 24, 210, 231, 20, 175, 241, 2, 1})},
}
ctx = metadata.NewIncomingContext(ctx, md)
ctx = handler.TagRPC(ctx, &stats.RPCTagInfo{
FullMethodName: "/service.foo/method",
})
handler.HandleRPC(ctx, &stats.Begin{
BeginTime: time.Now(),
})
handler.HandleRPC(ctx, &stats.End{
EndTime: time.Now(),
})

rows, err := view.RetrieveData(ServerRequestCountView.Name)
if err != nil {
t.Fatal(err)
}
traces := te.buffer

if got, want := len(rows), 1; got != want {
t.Errorf("Got %v rows; want %v", got, want)
}
if got, want := len(traces), test.expectTraces; got != want {
t.Errorf("Got %v traces; want %v", got, want)
}

// Cleanup.
view.Unsubscribe(ServerRequestCountView)
})
}
traces := te.buffer

if got, want := len(stats), 1; got != want {
t.Errorf("Got %v stats; want %v", got, want)
}
if got, want := len(traces), 1; got != want {
t.Errorf("Got %v traces; want %v", got, want)
}

// Cleanup.
view.Unsubscribe(ServerRequestCountView)
}

type traceExporter struct {
Expand Down
41 changes: 36 additions & 5 deletions plugin/ocgrpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,72 @@
package ocgrpc

import (
"go.opencensus.io/trace"
"golang.org/x/net/context"

"google.golang.org/grpc/stats"
)

// ServerHandler implements gRPC stats.Handler recording OpenCensus stats and
// traces. Use with gRPC servers.
//
// When installed (see Example), tracing metadata is read from inbound RPCs
// by default. If no tracing metadata is present, or if the tracing metadata is
// present but the SpanContext isn't sampled, then a new trace may be started
// (as determined by Sampler).
type ServerHandler struct {
// NoTrace may be set to disable recording OpenCensus Spans around
// gRPC methods.
// NoTrace may be set to true to disable OpenCensus tracing integration.
// If set to true, no trace metadata will be read from inbound RPCs and no
// new Spans will be created.
NoTrace bool

// NoStats may be set to disable recording OpenCensus Stats around each
// gRPC method.
// NoStats may be set to true to disable recording OpenCensus stats for RPCs.
NoStats bool

// IsPublicEndpoint may be set to true to always start a new trace around
// each RPC. Any SpanContext in the RPC metadata will be added as a linked
// span instead of making it the parent of the span created around the
// server RPC.
//
// Be aware that if you leave this false (the default) on a public-facing
// server, callers will be able to send tracing metadata in gRPC headers
// and trigger traces in your backend.
IsPublicEndpoint bool

// StartOptions to use for to spans started around RPCs handled by this server.
//
// These will apply even if there is tracing metadata already
// present on the inbound RPC but the SpanContext is not sampled. This
// ensures that each service has some opportunity to be traced. If you would
// like to not add any additional traces for this gRPC service, set:
// StartOptions.Sampler = trace.ProbabilitySampler(0.0)
StartOptions trace.StartOptions
}

var _ stats.Handler = (*ServerHandler)(nil)

// HandleConn exists to satisfy gRPC stats.Handler.
func (s *ServerHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
// no-op
}

// TagConn exists to satisfy gRPC stats.Handler.
func (s *ServerHandler) TagConn(ctx context.Context, cti *stats.ConnTagInfo) context.Context {
// no-op
return ctx
}

// HandleRPC implements per-RPC tracing and stats instrumentation.
func (s *ServerHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) {
if !s.NoTrace {
s.traceHandleRPC(ctx, rs)
traceHandleRPC(ctx, rs)
}
if !s.NoStats {
s.statsHandleRPC(ctx, rs)
}
}

// TagRPC implements per-RPC context management.
func (s *ServerHandler) TagRPC(ctx context.Context, rti *stats.RPCTagInfo) context.Context {
if !s.NoTrace {
ctx = s.traceTagRPC(ctx, rti)
Expand Down
8 changes: 4 additions & 4 deletions plugin/ocgrpc/server_stats_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
"google.golang.org/grpc/status"
)

// TagRPC gets the metadata from gRPC context, extracts the encoded tags from
// statsTagRPC gets the metadata from gRPC context, extracts the encoded tags from
// it and creates a new tag.Map and puts them into the returned context.
func (h *ServerHandler) statsTagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
startTime := time.Now()
Expand All @@ -48,7 +48,7 @@ func (h *ServerHandler) statsTagRPC(ctx context.Context, info *stats.RPCTagInfo)
return context.WithValue(ctx, grpcServerRPCKey, d)
}

// HandleRPC processes the RPC events.
// statsHandleRPC processes the RPC events.
func (h *ServerHandler) statsHandleRPC(ctx context.Context, s stats.RPCStats) {
switch st := s.(type) {
case *stats.Begin, *stats.InHeader, *stats.InTrailer, *stats.OutHeader, *stats.OutTrailer:
Expand All @@ -69,7 +69,7 @@ func (h *ServerHandler) handleRPCInPayload(ctx context.Context, s *stats.InPaylo
d, ok := ctx.Value(grpcServerRPCKey).(*rpcData)
if !ok {
if grpclog.V(2) {
grpclog.Infoln("serverHandler.handleRPCInPayload failed to retrieve *rpcData from context")
grpclog.Infoln("handleRPCInPayload: failed to retrieve *rpcData from context")
}
return
}
Expand All @@ -82,7 +82,7 @@ func (h *ServerHandler) handleRPCOutPayload(ctx context.Context, s *stats.OutPay
d, ok := ctx.Value(grpcServerRPCKey).(*rpcData)
if !ok {
if grpclog.V(2) {
grpclog.Infoln("serverHandler.handleRPCOutPayload failed to retrieve *rpcData from context")
grpclog.Infoln("handleRPCOutPayload: failed to retrieve *rpcData from context")
}
return
}
Expand Down
36 changes: 19 additions & 17 deletions plugin/ocgrpc/trace_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,30 @@ func (c *ClientHandler) traceTagRPC(ctx context.Context, rti *stats.RPCTagInfo)
func (s *ServerHandler) traceTagRPC(ctx context.Context, rti *stats.RPCTagInfo) context.Context {
md, _ := metadata.FromIncomingContext(ctx)
name := "Recv" + strings.Replace(rti.FullMethodName, "/", ".", -1)
if s := md[traceContextKey]; len(s) > 0 {
if parent, ok := propagation.FromBinary([]byte(s[0])); ok {
span := trace.NewSpanWithRemoteParent(name, parent, trace.StartOptions{})
traceContext := md[traceContextKey]
var (
parent trace.SpanContext
haveParent bool
)
if len(traceContext) > 0 {
// Metadata with keys ending in -bin are actually binary. They are base64
// encoded before being put on the wire, see:
// https://github.com/grpc/grpc-go/blob/08d6261/Documentation/grpc-metadata.md#storing-binary-data-in-metadata
traceContextBinary := []byte(traceContext[0])
parent, haveParent = propagation.FromBinary(traceContextBinary)
if haveParent && !s.IsPublicEndpoint {
span := trace.NewSpanWithRemoteParent(name, parent, s.StartOptions)
return trace.WithSpan(ctx, span)
}
}
// TODO(ramonza): should we ignore the in-process parent here?
ctx, _ = trace.StartSpan(ctx, name)
return ctx
}

// HandleRPC processes the RPC stats, adding information to the current trace span.
func (c *ClientHandler) traceHandleRPC(ctx context.Context, rs stats.RPCStats) {
handleRPC(ctx, rs)
}

// HandleRPC processes the RPC stats, adding information to the current trace span.
func (s *ServerHandler) traceHandleRPC(ctx context.Context, rs stats.RPCStats) {
handleRPC(ctx, rs)
span := trace.NewSpan(name, nil, s.StartOptions)
if haveParent {
span.AddLink(trace.Link{TraceID: parent.TraceID, SpanID: parent.SpanID, Type: trace.LinkTypeChild})
}
return trace.WithSpan(ctx, span)
}

func handleRPC(ctx context.Context, rs stats.RPCStats) {
func traceHandleRPC(ctx context.Context, rs stats.RPCStats) {
span := trace.FromContext(ctx)
// TODO: compressed and uncompressed sizes are not populated in every message.
switch rs := rs.(type) {
Expand Down
10 changes: 5 additions & 5 deletions plugin/ocgrpc/trace_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (s *testServer) Multiple(stream testpb.Foo_MultipleServer) error {
}
}

func newTestClientAndServer() (client testpb.FooClient, server *grpc.Server, cleanup func(), err error) {
func newTracingOnlyTestClientAndServer() (client testpb.FooClient, server *grpc.Server, cleanup func(), err error) {
// initialize server
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
Expand Down Expand Up @@ -93,7 +93,7 @@ func TestStreaming(t *testing.T) {
trace.RegisterExporter(&te)
defer trace.UnregisterExporter(&te)

client, _, cleanup, err := newTestClientAndServer()
client, _, cleanup, err := newTracingOnlyTestClientAndServer()
if err != nil {
t.Fatalf("initializing client and server: %v", err)
}
Expand Down Expand Up @@ -139,7 +139,7 @@ func TestStreamingFail(t *testing.T) {
trace.RegisterExporter(&te)
defer trace.UnregisterExporter(&te)

client, _, cleanup, err := newTestClientAndServer()
client, _, cleanup, err := newTracingOnlyTestClientAndServer()
if err != nil {
t.Fatalf("initializing client and server: %v", err)
}
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestSingle(t *testing.T) {
trace.RegisterExporter(&te)
defer trace.UnregisterExporter(&te)

client, _, cleanup, err := newTestClientAndServer()
client, _, cleanup, err := newTracingOnlyTestClientAndServer()
if err != nil {
t.Fatalf("initializing client and server: %v", err)
}
Expand Down Expand Up @@ -212,7 +212,7 @@ func TestSingleFail(t *testing.T) {
trace.RegisterExporter(&te)
defer trace.UnregisterExporter(&te)

client, _, cleanup, err := newTestClientAndServer()
client, _, cleanup, err := newTracingOnlyTestClientAndServer()
if err != nil {
t.Fatalf("initializing client and server: %v", err)
}
Expand Down

0 comments on commit 0ac2803

Please sign in to comment.