Skip to content
This repository has been archived by the owner on Jul 31, 2023. It is now read-only.

Commit

Permalink
Allow configuring ochttp.Handler for public servers
Browse files Browse the repository at this point in the history
Added a flag on ochttp.Handler that causes the incoming SpanContext
to be added as a link rather than a parent. This is useful if we don't
trust the metadata.
  • Loading branch information
Ramon Nogueira committed Mar 13, 2018
1 parent eea28b8 commit 8ed1dd7
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 102 deletions.
29 changes: 24 additions & 5 deletions plugin/ochttp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ type Handler struct {
// StartOptions are applied to the span started by this Handler around each
// request.
StartOptions trace.StartOptions

// IsPublicEndpoint should be set to true for publicly accessible HTTP(S)
// servers. If true, any trace metadata set on the incoming request will
// be added as a linked trace instead of being added as a parent of the
// current trace.
IsPublicEndpoint bool
}

func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand All @@ -77,22 +83,35 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

func (h *Handler) startTrace(w http.ResponseWriter, r *http.Request) (*http.Request, func()) {
name := spanNameFromURL("Recv", r.URL)
p := h.Propagation
if p == nil {
p = defaultFormat
}
ctx := r.Context()
var span *trace.Span
if sc, ok := p.SpanContextFromRequest(r); ok {
sc, haveSC := h.extractSpanContext(r)
if haveSC && !h.IsPublicEndpoint {
span = trace.NewSpanWithRemoteParent(name, sc, h.StartOptions)
ctx = trace.WithSpan(ctx, span)
} else {
span = trace.NewSpan(name, nil, h.StartOptions)
if haveSC {
span.AddLink(trace.Link{
TraceID: sc.TraceID,
SpanID: sc.SpanID,
Type: trace.LinkTypeChild,
Attributes: nil,
})
}
}
ctx = trace.WithSpan(ctx, span)
span.SetAttributes(requestAttrs(r)...)
return r.WithContext(trace.WithSpan(r.Context(), span)), span.End
}

func (h *Handler) extractSpanContext(r *http.Request) (trace.SpanContext, bool) {
if h.Propagation == nil {
return defaultFormat.SpanContextFromRequest(r)
}
return h.Propagation.SpanContextFromRequest(r)
}

func (h *Handler) startStats(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, func()) {
ctx, _ := tag.New(r.Context(),
tag.Upsert(Host, r.URL.Host),
Expand Down
245 changes: 148 additions & 97 deletions plugin/ochttp/trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import (
"testing"
"time"

"go.opencensus.io/plugin/ochttp/propagation/b3"
"go.opencensus.io/plugin/ochttp/propagation/tracecontext"
"go.opencensus.io/trace"
)

Expand Down Expand Up @@ -141,8 +143,6 @@ func TestHandler(t *testing.T) {

for _, tt := range tests {
t.Run(tt.header, func(t *testing.T) {
propagator := &testPropagator{}

handler := &Handler{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := trace.FromContext(r.Context())
Expand All @@ -155,7 +155,7 @@ func TestHandler(t *testing.T) {
}
}),
StartOptions: trace.StartOptions{Sampler: trace.ProbabilitySampler(0.0)},
Propagation: propagator,
Propagation: &testPropagator{},
}
req, _ := http.NewRequest("GET", "http://foo.com", nil)
req.Header.Add("trace", tt.header)
Expand All @@ -173,116 +173,167 @@ func (c *collector) ExportSpan(s *trace.SpanData) {
}

func TestEndToEnd(t *testing.T) {
var spans collector
trace.RegisterExporter(&spans)
defer trace.UnregisterExporter(&spans)

span := trace.NewSpan(
"top-level",
nil,
trace.StartOptions{
Sampler: trace.AlwaysSample(),
})
ctx := trace.WithSpan(context.Background(), span)

serverDone := make(chan struct{})
serverReturn := make(chan time.Time)
url := serveHTTP(serverDone, serverReturn)

req, err := http.NewRequest(
"POST",
fmt.Sprintf("%s/example/url/path?qparam=val", url),
strings.NewReader("expected-request-body"))
if err != nil {
t.Fatalf("unexpected error %#v", err)
trace.SetDefaultSampler(trace.AlwaysSample())

cases := []struct {
Name string
ServerMiddleware *Handler
ClientMiddleware *Transport
ExpectSameTraceID bool
ExpectLinks bool // expect a link between client and server span
}{
{
Name: "internal default propagation",
ServerMiddleware: &Handler{},
ClientMiddleware: &Transport{NoStats: true},
ExpectSameTraceID: true,
},
{
Name: "external default propagation",
ServerMiddleware: &Handler{IsPublicEndpoint: true},
ClientMiddleware: &Transport{NoStats: true},
ExpectSameTraceID: false,
ExpectLinks: true,
},
{
Name: "internal TraceContext propagation",
ServerMiddleware: &Handler{Propagation: &tracecontext.HTTPFormat{}},
ClientMiddleware: &Transport{NoStats: true, Propagation: &tracecontext.HTTPFormat{}},
ExpectSameTraceID: true,
},
{
Name: "misconfigured propagation",
ServerMiddleware: &Handler{IsPublicEndpoint: true, Propagation: &tracecontext.HTTPFormat{}},
ClientMiddleware: &Transport{NoStats: true, Propagation: &b3.HTTPFormat{}},
ExpectSameTraceID: false,
ExpectLinks: false,
},
}
req = req.WithContext(ctx)

rt := &Transport{
NoStats: true,
Propagation: defaultFormat,
Base: http.DefaultTransport,
}
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error %s", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected stats: %d", resp.StatusCode)
}
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
var spans collector
trace.RegisterExporter(&spans)
defer trace.UnregisterExporter(&spans)

root := trace.NewSpan(
"top-level",
nil,
trace.StartOptions{})
ctx := trace.WithSpan(context.Background(), root)

serverDone := make(chan struct{})
serverReturn := make(chan time.Time)
url := serveHTTP(c.ServerMiddleware, serverDone, serverReturn)

req, err := http.NewRequest(
http.MethodPost,
fmt.Sprintf("%s/example/url/path?qparam=val", url),
strings.NewReader("expected-request-body"))
if err != nil {
t.Fatalf("unexpected error %#v", err)
}
req = req.WithContext(ctx)

serverReturn <- time.Now().Add(time.Millisecond)
resp, err := c.ClientMiddleware.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error %s", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected stats: %d", resp.StatusCode)
}

respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unexpected read error: %#v", err)
}
if string(respBody) != "expected-response" {
t.Fatalf("unexpected response: %s", string(respBody))
}
serverReturn <- time.Now().Add(time.Millisecond)

resp.Body.Close()
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unexpected read error: %#v", err)
}
if string(respBody) != "expected-response" {
t.Fatalf("unexpected response: %s", string(respBody))
}

<-serverDone
trace.UnregisterExporter(&spans)
resp.Body.Close()

if got, want := len(spans), 2; got != want {
t.Fatalf("len(%#v) = %d; want %d", spans, got, want)
}
<-serverDone
trace.UnregisterExporter(&spans)

var client, server *trace.SpanData
for _, sp := range spans {
if strings.HasPrefix(sp.Name, "Sent.") {
client = sp
serverHostport := req.URL.Hostname() + ":" + req.URL.Port()
if got, want := client.Name, "Sent."+serverHostport+"/example/url/path"; got != want {
t.Errorf("Span name: %q; want %q", got, want)
if got, want := len(spans), 2; got != want {
t.Fatalf("len(spans) = %d; want %d", got, want)
}
} else if strings.HasPrefix(sp.Name, "Recv.") {
server = sp
if got, want := server.Name, "Recv./example/url/path"; got != want {
t.Errorf("Span name: %q; want %q", got, want)

var client, server *trace.SpanData
for _, sp := range spans {
if strings.HasPrefix(sp.Name, "Sent.") {
client = sp
serverHostport := req.URL.Hostname() + ":" + req.URL.Port()
if got, want := client.Name, "Sent."+serverHostport+"/example/url/path"; got != want {
t.Errorf("Span name: %q; want %q", got, want)
}
} else if strings.HasPrefix(sp.Name, "Recv.") {
server = sp
if got, want := server.Name, "Recv./example/url/path"; got != want {
t.Errorf("Span name: %q; want %q", got, want)
}
}
}
}
}

if server == nil || client == nil {
t.Fatalf("server or client span missing")
}
if server.TraceID != client.TraceID {
t.Errorf("TraceID does not match: server.TraceID=%q client.TraceID=%q", server.TraceID, client.TraceID)
}
if server.StartTime.Before(client.StartTime) {
t.Errorf("server span starts before client span")
}
if server.EndTime.After(client.EndTime) {
t.Errorf("client span ends before server span")
}
if !server.HasRemoteParent {
t.Errorf("server span should have remote parent")
}
if server.ParentSpanID != client.SpanID {
t.Errorf("server span should have client span as parent")
if server == nil || client == nil {
t.Fatalf("server or client span missing")
}
if c.ExpectSameTraceID {
if server.TraceID != client.TraceID {
t.Errorf("TraceID does not match: server.TraceID=%q client.TraceID=%q", server.TraceID, client.TraceID)
}
if !server.HasRemoteParent {
t.Errorf("server span should have remote parent")
}
if server.ParentSpanID != client.SpanID {
t.Errorf("server span should have client span as parent")
}
}
if !c.ExpectSameTraceID {
if server.TraceID == client.TraceID {
t.Errorf("TraceID should not be trusted")
}
}
if c.ExpectLinks {
if got, want := len(server.Links), 1; got != want {
t.Errorf("len(server.Links) = %d; want %d", got, want)
} else {
link := server.Links[0]
if got, want := link.TraceID, root.SpanContext().TraceID; got != want {
t.Errorf("link.TraceID = %q; want %q", got, want)
}
if got, want := link.Type, trace.LinkTypeChild; got != want {
t.Errorf("link.Type = %v; want %v", got, want)
}
}
}
if server.StartTime.Before(client.StartTime) {
t.Errorf("server span starts before client span")
}
if server.EndTime.After(client.EndTime) {
t.Errorf("client span ends before server span")
}
})
}
}

func serveHTTP(done chan struct{}, wait chan time.Time) string {
handler := &Handler{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.(http.Flusher).Flush()
func serveHTTP(handler *Handler, done chan struct{}, wait chan time.Time) string {
handler.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.(http.Flusher).Flush()

// simulate a slow-responding server
sleepUntil := <-wait
for time.Now().Before(sleepUntil) {
time.Sleep(sleepUntil.Sub(time.Now()))
}

io.WriteString(w, "expected-response")
close(done)
}),
}
// simulate a slow-responding server
sleepUntil := <-wait
for time.Now().Before(sleepUntil) {
time.Sleep(sleepUntil.Sub(time.Now()))
}

io.WriteString(w, "expected-response")
close(done)
})
server := httptest.NewServer(handler)
go func() {
<-done
Expand Down

0 comments on commit 8ed1dd7

Please sign in to comment.