diff --git a/pubsub/iterator.go b/pubsub/iterator.go index e551802596d2..fa7ae592d16b 100644 --- a/pubsub/iterator.go +++ b/pubsub/iterator.go @@ -68,12 +68,15 @@ type messageIterator struct { // message arrives, we'll record now+MaxExtension in this table; whenever we have a chance // to update ack deadlines (via modack), we'll consult this table and only include IDs // that are not beyond their deadline. - keepAliveDeadlines map[string]time.Time - pendingAcks map[string]bool - pendingNacks map[string]bool - pendingModAcks map[string]bool // ack IDs whose ack deadline is to be modified - err error // error from stream failure + keepAliveDeadlines map[string]time.Time + pendingAcks map[string]bool + pendingNacks map[string]bool + pendingModAcks map[string]bool // ack IDs whose ack deadline is to be modified + err error // error from stream failure + + eoMu sync.RWMutex enableExactlyOnceDelivery bool + sendNewAckDeadline bool } // newMessageIterator starts and returns a new messageIterator. @@ -280,7 +283,12 @@ func (it *messageIterator) recvMessages() ([]*pb.ReceivedMessage, error) { if err != nil { return nil, err } - it.enableExactlyOnceDelivery = res.GetSubscriptionProperties().GetExactlyOnceDeliveryEnabled() + it.eoMu.Lock() + if got := res.GetSubscriptionProperties().GetExactlyOnceDeliveryEnabled(); got != it.enableExactlyOnceDelivery { + it.sendNewAckDeadline = true + it.enableExactlyOnceDelivery = got + } + it.eoMu.Unlock() return res.ReceivedMessages, nil } @@ -534,8 +542,14 @@ func (it *messageIterator) sendAckIDRPC(ackIDSet map[string]bool, maxSize int, c // default ack deadline, and if the messages are small enough so that many can fit // into the buffer. func (it *messageIterator) pingStream() { - // Ignore error; if the stream is broken, this doesn't matter anyway. - _ = it.ps.Send(&pb.StreamingPullRequest{}) + spr := &pb.StreamingPullRequest{} + it.eoMu.RLock() + if it.sendNewAckDeadline { + spr.StreamAckDeadlineSeconds = int32(it.ackDeadline()) + it.sendNewAckDeadline = false + } + it.eoMu.RUnlock() + it.ps.Send(spr) } // calcFieldSizeString returns the number of bytes string fields @@ -583,8 +597,10 @@ func splitRequestIDs(ids []string, maxSize int) (prefix, remainder []string) { // expiration. func (it *messageIterator) ackDeadline() time.Duration { pt := time.Duration(it.ackTimeDist.Percentile(.99)) * time.Second - - return boundedDuration(pt, it.po.minExtensionPeriod, it.po.maxExtensionPeriod, it.enableExactlyOnceDelivery) + it.eoMu.RLock() + enableExactlyOnce := it.enableExactlyOnceDelivery + it.eoMu.RUnlock() + return boundedDuration(pt, it.po.minExtensionPeriod, it.po.maxExtensionPeriod, enableExactlyOnce) } func boundedDuration(ackDeadline, minExtension, maxExtension time.Duration, exactlyOnce bool) time.Duration { diff --git a/pubsub/iterator_test.go b/pubsub/iterator_test.go index eff9bb0507d6..9ba21fc7f8a5 100644 --- a/pubsub/iterator_test.go +++ b/pubsub/iterator_test.go @@ -36,7 +36,7 @@ import ( ) var ( - projName = "some-project" + projName = "P" topicName = "some-topic" subName = "some-sub" fullyQualifiedTopicName = fmt.Sprintf("projects/%s/topics/%s", projName, topicName) @@ -550,17 +550,9 @@ func TestIterator_StreamingPullExactlyOnce(t *testing.T) { } func TestAddToDistribution(t *testing.T) { - srv := pstest.NewServer() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + c, _ := newFake(t) - srv.Publish(fullyQualifiedTopicName, []byte("creating a topic"), nil) - - _, client, err := initConn(ctx, srv.Addr) - if err != nil { - t.Fatal(err) - } - iter := newMessageIterator(client.subc, fullyQualifiedTopicName, &pullOptions{}) + iter := newMessageIterator(c.subc, "some-sub", &pullOptions{}) // Start with a datapoint that's too small that should be bounded to 10s. receiveTime := time.Now().Add(time.Duration(-1) * time.Second) @@ -589,3 +581,40 @@ func TestAddToDistribution(t *testing.T) { t.Errorf("99th percentile ack distribution got: %v, want %v", deadline, want) } } + +func TestPingStreamAckDeadline(t *testing.T) { + c, srv := newFake(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + srv.Publish(fullyQualifiedTopicName, []byte("creating a topic"), nil) + topic := c.Topic(topicName) + s, err := c.CreateSubscription(ctx, subName, SubscriptionConfig{Topic: topic}) + if err != nil { + t.Errorf("failed to create subscription: %v", err) + } + + iter := newMessageIterator(c.subc, fullyQualifiedSubName, &pullOptions{}) + defer iter.stop() + + iter.eoMu.RLock() + if iter.enableExactlyOnceDelivery { + t.Error("iter.enableExactlyOnceDelivery should be false") + } + iter.eoMu.RUnlock() + + _, err = s.Update(ctx, SubscriptionConfigToUpdate{ + EnableExactlyOnceDelivery: true, + }) + if err != nil { + t.Error(err) + } + srv.Publish(fullyQualifiedTopicName, []byte("creating a topic"), nil) + // Receive one message via the stream to trigger the update to enableExactlyOnceDelivery + iter.receive(1) + iter.eoMu.RLock() + if !iter.enableExactlyOnceDelivery { + t.Error("iter.enableExactlyOnceDelivery should be true") + } + iter.eoMu.RUnlock() +} diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go index 9e66e5c5b129..6ed12e62a76e 100644 --- a/pubsub/pstest/fake.go +++ b/pubsub/pstest/fake.go @@ -619,6 +619,9 @@ func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscripti case "enable_exactly_once_delivery": sub.proto.EnableExactlyOnceDelivery = req.Subscription.EnableExactlyOnceDelivery + for _, st := range sub.streams { + st.enableExactlyOnceDelivery = req.Subscription.EnableExactlyOnceDelivery + } default: return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path) diff --git a/pubsub/subscription_test.go b/pubsub/subscription_test.go index 6a8b0771ed2f..bdc749890c1b 100644 --- a/pubsub/subscription_test.go +++ b/pubsub/subscription_test.go @@ -314,7 +314,7 @@ func (t1 *Topic) Equal(t2 *Topic) bool { func newFake(t *testing.T) (*Client, *pstest.Server) { ctx := context.Background() srv := pstest.NewServer() - client, err := NewClient(ctx, "P", + client, err := NewClient(ctx, projName, option.WithEndpoint(srv.Addr), option.WithoutAuthentication(), option.WithGRPCDialOption(grpc.WithInsecure()))