diff --git a/raft/node.go b/raft/node.go index f3ba250b9af9..ec3487b6e4a8 100644 --- a/raft/node.go +++ b/raft/node.go @@ -224,10 +224,15 @@ func RestartNode(c *Config) Node { return &n } +type proposingMsg struct { + m pb.Message + result chan error +} + // node is the canonical implementation of the Node interface type node struct { - propc chan pb.Message - recvc chan pb.Message + propc chan proposingMsg + recvc chan proposingMsg confc chan pb.ConfChange confstatec chan pb.ConfState readyc chan Ready @@ -242,8 +247,8 @@ type node struct { func newNode() node { return node{ - propc: make(chan pb.Message), - recvc: make(chan pb.Message), + propc: make(chan proposingMsg), + recvc: make(chan proposingMsg), confc: make(chan pb.ConfChange), confstatec: make(chan pb.ConfState), readyc: make(chan Ready), @@ -271,7 +276,7 @@ func (n *node) Stop() { } func (n *node) run(r *raft) { - var propc chan pb.Message + var propc chan proposingMsg var readyc chan Ready var advancec chan struct{} var prevLastUnstablei, prevLastUnstablet uint64 @@ -314,13 +319,27 @@ func (n *node) run(r *raft) { // TODO: maybe buffer the config propose if there exists one (the way // described in raft dissertation) // Currently it is dropped in Step silently. - case m := <-propc: + case pm := <-propc: + m := pm.m m.From = r.id - r.Step(m) - case m := <-n.recvc: + err := r.Step(m) + if pm.result != nil { + if err == ErrProposalDropped { + pm.result <- err + } + close(pm.result) + } + case pm := <-n.recvc: + m := pm.m // filter out response message from unknown From. if pr := r.getProgress(m.From); pr != nil || !IsResponseMsg(m.Type) { - r.Step(m) // raft never returns an error + err := r.Step(m) // raft never returns an error + if pm.result != nil { + if err == ErrProposalDropped { + pm.result <- err + } + close(pm.result) + } } case cc := <-n.confc: if cc.NodeID == None { @@ -408,7 +427,7 @@ func (n *node) Tick() { func (n *node) Campaign(ctx context.Context) error { return n.step(ctx, pb.Message{Type: pb.MsgHup}) } func (n *node) Propose(ctx context.Context, data []byte) error { - return n.step(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Data: data}}}) + return n.stepWait(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Data: data}}}, true) } func (n *node) Step(ctx context.Context, m pb.Message) error { @@ -428,22 +447,43 @@ func (n *node) ProposeConfChange(ctx context.Context, cc pb.ConfChange) error { return n.Step(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Type: pb.EntryConfChange, Data: data}}}) } +func (n *node) step(ctx context.Context, m pb.Message) error { + return n.stepWait(ctx, m, false) +} + // Step advances the state machine using msgs. The ctx.Err() will be returned, // if any. -func (n *node) step(ctx context.Context, m pb.Message) error { +func (n *node) stepWait(ctx context.Context, m pb.Message, wait bool) error { ch := n.recvc if m.Type == pb.MsgProp { ch = n.propc } + pm := proposingMsg{m: m, result: nil} + if wait { + pm.result = make(chan error) + } select { - case ch <- m: - return nil + case ch <- pm: + if !wait { + return nil + } + case <-ctx.Done(): + return ctx.Err() + case <-n.done: + return ErrStopped + } + select { + case rsp := <-pm.result: + if rsp != nil { + return rsp + } case <-ctx.Done(): return ctx.Err() case <-n.done: return ErrStopped } + return nil } func (n *node) Ready() <-chan Ready { return n.readyc } @@ -480,7 +520,7 @@ func (n *node) Status() Status { func (n *node) ReportUnreachable(id uint64) { select { - case n.recvc <- pb.Message{Type: pb.MsgUnreachable, From: id}: + case n.recvc <- proposingMsg{m: pb.Message{Type: pb.MsgUnreachable, From: id}, result: nil}: case <-n.done: } } @@ -489,7 +529,7 @@ func (n *node) ReportSnapshot(id uint64, status SnapshotStatus) { rej := status == SnapshotFailure select { - case n.recvc <- pb.Message{Type: pb.MsgSnapStatus, From: id, Reject: rej}: + case n.recvc <- proposingMsg{m: pb.Message{Type: pb.MsgSnapStatus, From: id, Reject: rej}, result: nil}: case <-n.done: } } @@ -497,7 +537,7 @@ func (n *node) ReportSnapshot(id uint64, status SnapshotStatus) { func (n *node) TransferLeadership(ctx context.Context, lead, transferee uint64) { select { // manually set 'from' and 'to', so that leader can voluntarily transfers its leadership - case n.recvc <- pb.Message{Type: pb.MsgTransferLeader, From: transferee, To: lead}: + case n.recvc <- proposingMsg{m: pb.Message{Type: pb.MsgTransferLeader, From: transferee, To: lead}, result: nil}: case <-n.done: case <-ctx.Done(): } diff --git a/raft/node_test.go b/raft/node_test.go index f884f3319a5f..4554dd6a321c 100644 --- a/raft/node_test.go +++ b/raft/node_test.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "reflect" + "strings" "testing" "time" @@ -30,8 +31,8 @@ import ( func TestNodeStep(t *testing.T) { for i, msgn := range raftpb.MessageType_name { n := &node{ - propc: make(chan raftpb.Message, 1), - recvc: make(chan raftpb.Message, 1), + propc: make(chan proposingMsg, 1), + recvc: make(chan proposingMsg, 1), } msgt := raftpb.MessageType(i) n.Step(context.TODO(), raftpb.Message{Type: msgt}) @@ -64,7 +65,7 @@ func TestNodeStep(t *testing.T) { func TestNodeStepUnblock(t *testing.T) { // a node without buffer to block step n := &node{ - propc: make(chan raftpb.Message), + propc: make(chan proposingMsg), done: make(chan struct{}), } @@ -433,6 +434,49 @@ func TestBlockProposal(t *testing.T) { } } +func TestNodeProposeWaitDropped(t *testing.T) { + msgs := []raftpb.Message{} + droppingMsg := []byte("test_dropping") + dropStep := func(r *raft, m raftpb.Message) error { + if m.Type == raftpb.MsgProp && strings.Contains(m.String(), string(droppingMsg)) { + t.Logf("dropping message: %v", m.String()) + return ErrProposalDropped + } + msgs = append(msgs, m) + return nil + } + + n := newNode() + s := NewMemoryStorage() + r := newTestRaft(1, []uint64{1}, 10, 1, s) + go n.run(r) + n.Campaign(context.TODO()) + for { + rd := <-n.Ready() + s.Append(rd.Entries) + // change the step function to dropStep until this raft becomes leader + if rd.SoftState.Lead == r.id { + r.step = dropStep + n.Advance() + break + } + n.Advance() + } + proposalTimeout := time.Millisecond * 100 + ctx, cancel := context.WithTimeout(context.Background(), proposalTimeout) + // propose with cancel should be cancelled earyly if dropped + err := n.Propose(ctx, droppingMsg) + if err != ErrProposalDropped { + t.Errorf("should drop proposal : %v", err) + } + cancel() + + n.Stop() + if len(msgs) != 0 { + t.Fatalf("len(msgs) = %d, want %d", len(msgs), 1) + } +} + // TestNodeTick ensures that node.Tick() will increase the // elapsed of the underlying raft state machine. func TestNodeTick(t *testing.T) {