Skip to content
This repository has been archived by the owner on Jul 31, 2023. It is now read-only.

Allow configuring ochttp.Handler for public servers #563

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions plugin/ochttp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ type Handler struct {
// StartOptions are applied to the span started by this Handler around each
// request.
StartOptions trace.StartOptions

// IsPublicEndpoint should be set to true for publicly accessible HTTP(S)
// servers. If true, any trace metadata set on the incoming request will
// be added as a linked trace instead of being added as a parent of the
// current trace.
IsPublicEndpoint bool
}

func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand All @@ -77,22 +83,35 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

func (h *Handler) startTrace(w http.ResponseWriter, r *http.Request) (*http.Request, func()) {
name := spanNameFromURL("Recv", r.URL)
p := h.Propagation
if p == nil {
p = defaultFormat
}
ctx := r.Context()
var span *trace.Span
if sc, ok := p.SpanContextFromRequest(r); ok {
sc, ok := h.extractSpanContext(r)
if ok && !h.IsPublicEndpoint {
span = trace.NewSpanWithRemoteParent(name, sc, h.StartOptions)
ctx = trace.WithSpan(ctx, span)
} else {
span = trace.NewSpan(name, nil, h.StartOptions)
if ok {
span.AddLink(trace.Link{
TraceID: sc.TraceID,
SpanID: sc.SpanID,
Type: trace.LinkTypeChild,
Attributes: nil,
})
}
}
ctx = trace.WithSpan(ctx, span)
span.SetAttributes(requestAttrs(r)...)
return r.WithContext(trace.WithSpan(r.Context(), span)), span.End
}

func (h *Handler) extractSpanContext(r *http.Request) (trace.SpanContext, bool) {
if h.Propagation == nil {
return defaultFormat.SpanContextFromRequest(r)
}
return h.Propagation.SpanContextFromRequest(r)
}

func (h *Handler) startStats(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, func()) {
ctx, _ := tag.New(r.Context(),
tag.Upsert(Host, r.URL.Host),
Expand Down
250 changes: 152 additions & 98 deletions plugin/ochttp/trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import (
"testing"
"time"

"go.opencensus.io/plugin/ochttp/propagation/b3"
"go.opencensus.io/plugin/ochttp/propagation/tracecontext"
"go.opencensus.io/trace"
)

Expand Down Expand Up @@ -141,8 +143,6 @@ func TestHandler(t *testing.T) {

for _, tt := range tests {
t.Run(tt.header, func(t *testing.T) {
propagator := &testPropagator{}

handler := &Handler{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := trace.FromContext(r.Context())
Expand All @@ -155,7 +155,7 @@ func TestHandler(t *testing.T) {
}
}),
StartOptions: trace.StartOptions{Sampler: trace.ProbabilitySampler(0.0)},
Propagation: propagator,
Propagation: &testPropagator{},
}
req, _ := http.NewRequest("GET", "http://foo.com", nil)
req.Header.Add("trace", tt.header)
Expand All @@ -173,116 +173,170 @@ func (c *collector) ExportSpan(s *trace.SpanData) {
}

func TestEndToEnd(t *testing.T) {
var spans collector
trace.RegisterExporter(&spans)
defer trace.UnregisterExporter(&spans)

span := trace.NewSpan(
"top-level",
nil,
trace.StartOptions{
Sampler: trace.AlwaysSample(),
})
ctx := trace.WithSpan(context.Background(), span)

serverDone := make(chan struct{})
serverReturn := make(chan time.Time)
url := serveHTTP(serverDone, serverReturn)

req, err := http.NewRequest(
"POST",
fmt.Sprintf("%s/example/url/path?qparam=val", url),
strings.NewReader("expected-request-body"))
if err != nil {
t.Fatalf("unexpected error %#v", err)
trace.SetDefaultSampler(trace.AlwaysSample())

tc := []struct {
name string
handler *Handler
transport *Transport
wantSameTraceID bool
wantLinks bool // expect a link between client and server span
}{
{
name: "internal default propagation",
handler: &Handler{},
transport: &Transport{NoStats: true},
wantSameTraceID: true,
},
{
name: "external default propagation",
handler: &Handler{IsPublicEndpoint: true},
transport: &Transport{NoStats: true},
wantSameTraceID: false,
wantLinks: true,
},
{
name: "internal TraceContext propagation",
handler: &Handler{Propagation: &tracecontext.HTTPFormat{}},
transport: &Transport{NoStats: true, Propagation: &tracecontext.HTTPFormat{}},
wantSameTraceID: true,
},
{
name: "misconfigured propagation",
handler: &Handler{IsPublicEndpoint: true, Propagation: &tracecontext.HTTPFormat{}},
transport: &Transport{NoStats: true, Propagation: &b3.HTTPFormat{}},
wantSameTraceID: false,
wantLinks: false,
},
}
req = req.WithContext(ctx)

rt := &Transport{
NoStats: true,
Propagation: defaultFormat,
Base: http.DefaultTransport,
}
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error %s", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected stats: %d", resp.StatusCode)
}
for _, tt := range tc {
t.Run(tt.name, func(t *testing.T) {
var spans collector
trace.RegisterExporter(&spans)
defer trace.UnregisterExporter(&spans)

// Start the server.
serverDone := make(chan struct{})
serverReturn := make(chan time.Time)
url := serveHTTP(tt.handler, serverDone, serverReturn)

// Start a root Span in the client.
root := trace.NewSpan(
"top-level",
nil,
trace.StartOptions{})
ctx := trace.WithSpan(context.Background(), root)

// Make the request.
req, err := http.NewRequest(
http.MethodPost,
fmt.Sprintf("%s/example/url/path?qparam=val", url),
strings.NewReader("expected-request-body"))
if err != nil {
t.Fatal(err)
}
req = req.WithContext(ctx)
resp, err := tt.transport.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("resp.StatusCode = %d", resp.StatusCode)
}

serverReturn <- time.Now().Add(time.Millisecond)
// Tell the server to return from request handling.
serverReturn <- time.Now().Add(time.Millisecond)

respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unexpected read error: %#v", err)
}
if string(respBody) != "expected-response" {
t.Fatalf("unexpected response: %s", string(respBody))
}
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if got, want := string(respBody), "expected-response"; got != want {
t.Fatalf("respBody = %q; want %q", got, want)
}

resp.Body.Close()
resp.Body.Close()

<-serverDone
trace.UnregisterExporter(&spans)
<-serverDone
trace.UnregisterExporter(&spans)

if got, want := len(spans), 2; got != want {
t.Fatalf("len(%#v) = %d; want %d", spans, got, want)
}

var client, server *trace.SpanData
for _, sp := range spans {
if strings.HasPrefix(sp.Name, "Sent.") {
client = sp
serverHostport := req.URL.Hostname() + ":" + req.URL.Port()
if got, want := client.Name, "Sent."+serverHostport+"/example/url/path"; got != want {
t.Errorf("Span name: %q; want %q", got, want)
if got, want := len(spans), 2; got != want {
t.Fatalf("len(spans) = %d; want %d", got, want)
}
} else if strings.HasPrefix(sp.Name, "Recv.") {
server = sp
if got, want := server.Name, "Recv./example/url/path"; got != want {
t.Errorf("Span name: %q; want %q", got, want)

var client, server *trace.SpanData
for _, sp := range spans {
if strings.HasPrefix(sp.Name, "Sent.") {
client = sp
serverHostport := req.URL.Hostname() + ":" + req.URL.Port()
if got, want := client.Name, "Sent."+serverHostport+"/example/url/path"; got != want {
t.Errorf("Span name: %q; want %q", got, want)
}
} else if strings.HasPrefix(sp.Name, "Recv.") {
server = sp
if got, want := server.Name, "Recv./example/url/path"; got != want {
t.Errorf("Span name: %q; want %q", got, want)
}
}
}
}
}

if server == nil || client == nil {
t.Fatalf("server or client span missing")
}
if server.TraceID != client.TraceID {
t.Errorf("TraceID does not match: server.TraceID=%q client.TraceID=%q", server.TraceID, client.TraceID)
}
if server.StartTime.Before(client.StartTime) {
t.Errorf("server span starts before client span")
}
if server.EndTime.After(client.EndTime) {
t.Errorf("client span ends before server span")
}
if !server.HasRemoteParent {
t.Errorf("server span should have remote parent")
}
if server.ParentSpanID != client.SpanID {
t.Errorf("server span should have client span as parent")
if server == nil || client == nil {
t.Fatalf("server or client span missing")
}
if tt.wantSameTraceID {
if server.TraceID != client.TraceID {
t.Errorf("TraceID does not match: server.TraceID=%q client.TraceID=%q", server.TraceID, client.TraceID)
}
if !server.HasRemoteParent {
t.Errorf("server span should have remote parent")
}
if server.ParentSpanID != client.SpanID {
t.Errorf("server span should have client span as parent")
}
}
if !tt.wantSameTraceID {
if server.TraceID == client.TraceID {
t.Errorf("TraceID should not be trusted")
}
}
if tt.wantLinks {
if got, want := len(server.Links), 1; got != want {
t.Errorf("len(server.Links) = %d; want %d", got, want)
} else {
link := server.Links[0]
if got, want := link.TraceID, root.SpanContext().TraceID; got != want {
t.Errorf("link.TraceID = %q; want %q", got, want)
}
if got, want := link.Type, trace.LinkTypeChild; got != want {
t.Errorf("link.Type = %v; want %v", got, want)
}
}
}
if server.StartTime.Before(client.StartTime) {
t.Errorf("server span starts before client span")
}
if server.EndTime.After(client.EndTime) {
t.Errorf("client span ends before server span")
}
})
}
}

func serveHTTP(done chan struct{}, wait chan time.Time) string {
handler := &Handler{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.(http.Flusher).Flush()
func serveHTTP(handler *Handler, done chan struct{}, wait chan time.Time) string {
handler.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.(http.Flusher).Flush()

// simulate a slow-responding server
sleepUntil := <-wait
for time.Now().Before(sleepUntil) {
time.Sleep(sleepUntil.Sub(time.Now()))
}

io.WriteString(w, "expected-response")
close(done)
}),
}
// Simulate a slow-responding server.
sleepUntil := <-wait
for time.Now().Before(sleepUntil) {
time.Sleep(sleepUntil.Sub(time.Now()))
}

io.WriteString(w, "expected-response")
close(done)
})
server := httptest.NewServer(handler)
go func() {
<-done
Expand Down