diff --git a/README.md b/README.md index 2dbd1564..620d9ae7 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,8 @@ attack command: Duration of the test [0 = forever] -format string Targets format [http, json] (default "http") + -graphql + Enable GraphQL response awareness -h2c Send HTTP/2 requests without TLS encryption -header value diff --git a/attack.go b/attack.go index bad0896e..5a93d79e 100644 --- a/attack.go +++ b/attack.go @@ -58,6 +58,7 @@ func attackCmd() command { fs.StringVar(&opts.unixSocket, "unix-socket", "", "Connect over a unix socket. This overrides the host address in target URLs") fs.Var(&dnsTTLFlag{&opts.dnsTTL}, "dns-ttl", "Cache DNS lookups for the given duration [-1 = disabled, 0 = forever]") fs.BoolVar(&opts.sessionTickets, "session-tickets", false, "Enable TLS session resumption using session tickets") + fs.BoolVar(&opts.graphql, "graphql", false, "Enable GraphQL response awareness") systemSpecificFlags(fs, opts) return command{fs, func(args []string) error { @@ -103,6 +104,7 @@ type attackOpts struct { unixSocket string dnsTTL time.Duration sessionTickets bool + graphql bool } // attack validates the attack arguments, sets up the @@ -196,6 +198,7 @@ func attack(opts *attackOpts) (err error) { vegeta.ChunkedBody(opts.chunked), vegeta.DNSCaching(opts.dnsTTL), vegeta.SessionTickets(opts.sessionTickets), + vegeta.GraphQL(opts.graphql), ) res := atk.Attack(tr, opts.rate, opts.duration, opts.name) diff --git a/lib/attack.go b/lib/attack.go index ce1e76a7..ea785848 100644 --- a/lib/attack.go +++ b/lib/attack.go @@ -3,6 +3,7 @@ package vegeta import ( "context" "crypto/tls" + "encoding/json" "fmt" "io" "io/ioutil" @@ -30,6 +31,7 @@ type Attacker struct { maxBody int64 redirects int chunked bool + graphql bool } const ( @@ -373,6 +375,11 @@ func DNSCaching(ttl time.Duration) func(*Attacker) { } } +// GraphQL returns a functional option which sets the attacker's graphql awareness +func GraphQL(gql bool) func(*Attacker) { + return func(a *Attacker) { a.graphql = gql } +} + type attack struct { name string began time.Time @@ -475,9 +482,19 @@ func (a *Attacker) attack(tr Targeter, atk *attack, workers *sync.WaitGroup, tic } } +type GQLResponse struct { + Data interface{} + Errors []GQLError +} + +type GQLError struct { + Extensions interface{} + Message string +} + func (a *Attacker) hit(tr Targeter, atk *attack) *Result { var ( - res = Result{Attack: atk.name} + res = Result{Attack: atk.name, GraphQL: a.graphql} tgt Target err error ) @@ -521,6 +538,9 @@ func (a *Attacker) hit(tr Targeter, atk *attack) *Result { if atk.name != "" { req.Header.Set("X-Vegeta-Attack", atk.name) } + if a.graphql { + req.Header.Set("X-Vegeta-Gql", "true") + } req.Header.Set("X-Vegeta-Seq", strconv.FormatUint(res.Seq, 10)) @@ -557,5 +577,26 @@ func (a *Attacker) hit(tr Targeter, atk *attack) *Result { res.Headers = r.Header + if a.graphql { + var response GQLResponse + err = json.Unmarshal(res.Body, &response) + if err != nil { + res.Error = err.Error() + } + + if response.Errors != nil && len(response.Errors) > 0 { + var error string + for i, e := range response.Errors { + if i == 0 { + error = e.Message + } else { + error += fmt.Sprintf("%v, %v", error, e.Message) + } + } + + res.Error = error + } + } + return &res } diff --git a/lib/attack_test.go b/lib/attack_test.go index 88d09aa6..db2f31d5 100644 --- a/lib/attack_test.go +++ b/lib/attack_test.go @@ -407,3 +407,37 @@ func TestVegetaHeaders(t *testing.T) { } } } + +func TestGraphqlAwareness(t *testing.T) { + t.Parallel() + + errRes := []byte("{\"data\":{}, \"errors\":[{\"message\":\"no punch\"}]}") + ok := []byte("{\"data\":{\"vegeta\":\"punch\"}}") + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if seq, err := strconv.Atoi(r.Header.Get("X-Vegeta-Seq")); err == nil && seq%2 == 0 { + w.Write(ok) + } else { + w.Write(errRes) + } + }), + ) + defer server.Close() + + tr := NewStaticTargeter(Target{Method: "POST", URL: server.URL}) + a := NewAttacker(GraphQL(true)) + atk := &attack{name: "gql", began: time.Now()} + for seq := 0; seq < 5; seq++ { + res := a.hit(tr, atk) + + if !res.GraphQL { + t.Errorf("GraphQL flag not set") + } + if have, want := res.Code, uint16(200); have != want { + t.Errorf("Code: have %q, want %q", have, want) + } + if have, want := res.Error, "no punch"; seq%2 != 0 && have != want { + t.Errorf("Expected error: have %q, want %q", have, want) + } + } +} diff --git a/lib/metrics.go b/lib/metrics.go index 5d7dbf9b..9c37531a 100644 --- a/lib/metrics.go +++ b/lib/metrics.go @@ -69,7 +69,7 @@ func (m *Metrics) Add(r *Result) { m.End = end } - if r.Code >= 200 && r.Code < 400 { + if r.Code >= 200 && r.Code < 400 && (!r.GraphQL || r.Error == "") { m.success++ } diff --git a/lib/metrics_test.go b/lib/metrics_test.go index 8d22f1a4..5c3610cf 100644 --- a/lib/metrics_test.go +++ b/lib/metrics_test.go @@ -22,13 +22,15 @@ func TestMetrics_Add(t *testing.T) { var got Metrics for i := 1; i <= 10000; i++ { + code := codes[i%len(codes)] got.Add(&Result{ - Code: codes[i%len(codes)], + Code: code, Timestamp: time.Unix(int64(i-1), 0), Latency: time.Duration(i) * time.Microsecond, BytesIn: 1024, BytesOut: 512, Error: errors[i%len(errors)], + GraphQL: code == 200 && i%10 == 0, }) } got.Close() @@ -62,8 +64,8 @@ func TestMetrics_Add(t *testing.T) { Wait: duration("10ms"), Requests: 10000, Rate: 1.000100010001, - Throughput: 0.6667660098349737, - Success: 0.6667, + Throughput: 0.6333627029075878, + Success: 0.6333, StatusCodes: map[string]int{"500": 3333, "200": 3334, "302": 3333}, Errors: []string{"Internal server error"}, diff --git a/lib/results.go b/lib/results.go index eacd42d0..29057aaf 100644 --- a/lib/results.go +++ b/lib/results.go @@ -36,6 +36,7 @@ type Result struct { Method string `json:"method"` URL string `json:"url"` Headers http.Header `json:"headers"` + GraphQL bool `json:"graphql"` } // End returns the time at which a Result ended.