diff --git a/tools/benchmark/cmd/root.go b/tools/benchmark/cmd/root.go index e37e87c206c..6ac798f5c06 100644 --- a/tools/benchmark/cmd/root.go +++ b/tools/benchmark/cmd/root.go @@ -52,6 +52,8 @@ var ( user string dialTimeout time.Duration + + targetLeader bool ) func init() { @@ -67,4 +69,6 @@ func init() { RootCmd.PersistentFlags().StringVar(&user, "user", "", "specify username and password in username:password format") RootCmd.PersistentFlags().DurationVar(&dialTimeout, "dial-timeout", 0, "dial timeout for client connections") + + RootCmd.PersistentFlags().BoolVar(&targetLeader, "target-leader", false, "connect only to the leader node") } diff --git a/tools/benchmark/cmd/util.go b/tools/benchmark/cmd/util.go index 7369db701e7..7775acce426 100644 --- a/tools/benchmark/cmd/util.go +++ b/tools/benchmark/cmd/util.go @@ -23,19 +23,53 @@ import ( "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/pkg/report" + "golang.org/x/net/context" ) var ( // dialTotal counts the number of mustCreateConn calls so that endpoint // connections can be handed out in round-robin order dialTotal int + + // leaderEps is a cache for holding endpoints of a leader node + leaderEps []string ) +func mustFindLeaderEndpoints(c *clientv3.Client) { + resp, lerr := c.MemberList(context.TODO()) + if lerr != nil { + fmt.Fprintf(os.Stderr, "failed to get a member list: %s\n", lerr) + os.Exit(1) + } + + leaderId := uint64(0) + for _, ep := range c.Endpoints() { + resp, serr := c.Status(context.TODO(), ep) + if serr == nil { + leaderId = resp.Leader + break + } + } + + for _, m := range resp.Members { + if m.ID == leaderId { + leaderEps = m.ClientURLs + return + } + } + + fmt.Fprintf(os.Stderr, "failed to find a leader endpoint\n") + os.Exit(1) +} + func mustCreateConn() *clientv3.Client { - endpoint := endpoints[dialTotal%len(endpoints)] - dialTotal++ + connEndpoints := leaderEps + if len(connEndpoints) == 0 { + connEndpoints = []string{endpoints[dialTotal%len(endpoints)]} + dialTotal++ + } cfg := clientv3.Config{ - Endpoints: []string{endpoint}, + Endpoints: connEndpoints, DialTimeout: dialTimeout, } if !tls.Empty() { @@ -59,12 +93,19 @@ func mustCreateConn() *clientv3.Client { } client, err := clientv3.New(cfg) + if targetLeader && len(leaderEps) == 0 { + mustFindLeaderEndpoints(client) + client.Close() + return mustCreateConn() + } + clientv3.SetLogger(log.New(os.Stderr, "grpc", 0)) if err != nil { fmt.Fprintf(os.Stderr, "dial error: %v\n", err) os.Exit(1) } + return client }