diff --git a/pubsub/consumer.go b/pubsub/consumer.go index 9c3785ee3f..2c4787101d 100644 --- a/pubsub/consumer.go +++ b/pubsub/consumer.go @@ -18,7 +18,7 @@ type ConsumerConfig struct { // Timeout of result entry in Redis. ResponseEntryTimeout time.Duration `koanf:"response-entry-timeout"` // Minimum idle time after which messages will be autoclaimed - IdletimeToAutoclaim time.Duration `koanf:"Idletime-to-autoclaim"` + IdletimeToAutoclaim time.Duration `koanf:"idletime-to-autoclaim"` } var DefaultConsumerConfig = ConsumerConfig{ @@ -33,7 +33,7 @@ var TestConsumerConfig = ConsumerConfig{ func ConsumerConfigAddOptions(prefix string, f *pflag.FlagSet) { f.Duration(prefix+".response-entry-timeout", DefaultConsumerConfig.ResponseEntryTimeout, "timeout for response entry") - f.Duration(prefix+".Idletime-to-autoclaim", DefaultConsumerConfig.IdletimeToAutoclaim, "After a message spends this amount of time in PEL (Pending Entries List i.e claimed by another consumer but not Acknowledged) it will be allowed to be autoclaimed by other consumers") + f.Duration(prefix+".idletime-to-autoclaim", DefaultConsumerConfig.IdletimeToAutoclaim, "After a message spends this amount of time in PEL (Pending Entries List i.e claimed by another consumer but not Acknowledged) it will be allowed to be autoclaimed by other consumers") } // Consumer implements a consumer for redis stream provides heartbeat to @@ -93,9 +93,12 @@ func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Req MinIdle: c.cfg.IdletimeToAutoclaim, // Minimum idle time for messages to claim (in milliseconds) Stream: c.redisStream, Start: "0", - Count: 1, // Limit the number of messages to claim + Count: 5, // Try looking for 50 entries in PEL, this assumes there are a maximum of 50 consumers in this redisGroup }).Result() - if len(messages) != 1 || err != nil { + if len(messages) == 0 || err != nil { + if err != nil { + log.Error("error from xautoclaim", "err", err) + } // Fallback to reading new messages res, err := c.client.XReadGroup(ctx, &redis.XReadGroupArgs{ Group: c.redisGroup, @@ -132,7 +135,9 @@ func (c *Consumer[Request, Response]) Consume(ctx context.Context) (*Message[Req ackNotifier := make(chan struct{}) c.StopWaiter.LaunchThread(func(ctx context.Context) { for { - if err := c.client.XClaim(ctx, &redis.XClaimArgs{ + // Use XClaimJustID so that we would have clear difference between invalid requests that are claimed multiple times due to xautoclaim and + // valid requests that are just being claimed in regular intervals to indicate heartbeat + if err := c.client.XClaimJustID(ctx, &redis.XClaimArgs{ Stream: c.redisStream, Group: c.redisGroup, Consumer: c.id, @@ -174,5 +179,8 @@ func (c *Consumer[Request, Response]) SetResult(ctx context.Context, id string, if _, err := c.client.XAck(ctx, c.redisStream, c.redisGroup, messageID).Result(); err != nil { return fmt.Errorf("acking message: %v, error: %w", messageID, err) } + if _, err := c.client.XDel(ctx, c.redisStream, messageID).Result(); err != nil { + return fmt.Errorf("deleting message: %v, error: %w", messageID, err) + } return nil } diff --git a/pubsub/producer.go b/pubsub/producer.go index df6e7d5a28..cf5dfdbd36 100644 --- a/pubsub/producer.go +++ b/pubsub/producer.go @@ -57,21 +57,26 @@ type ProducerConfig struct { CheckResultInterval time.Duration `koanf:"check-result-interval"` // Timeout of entry's written to redis by producer ResponseEntryTimeout time.Duration `koanf:"response-entry-timeout"` + // RequestTimeout is a TTL for any message sent to the redis stream + RequestTimeout time.Duration `koanf:"request-timeout"` } var DefaultProducerConfig = ProducerConfig{ CheckResultInterval: 5 * time.Second, ResponseEntryTimeout: time.Hour, + RequestTimeout: time.Hour, // should we increase this? } var TestProducerConfig = ProducerConfig{ CheckResultInterval: 5 * time.Millisecond, ResponseEntryTimeout: time.Minute, + RequestTimeout: 2 * time.Second, } func ProducerAddConfigAddOptions(prefix string, f *pflag.FlagSet) { f.Duration(prefix+".check-result-interval", DefaultProducerConfig.CheckResultInterval, "interval in which producer checks pending messages whether consumer processing them is inactive") f.Duration(prefix+".response-entry-timeout", DefaultProducerConfig.ResponseEntryTimeout, "timeout after which responses written from producer to the redis are cleared. Currently used for the key mapping unique request id to redis stream message id") + f.Duration(prefix+".request-timeout", DefaultProducerConfig.RequestTimeout, "timeout after which the message in redis stream is considered as errored, this prevents workers from working on wrong requests indefinitely") } func NewProducer[Request any, Response any](client redis.UniversalClient, streamName string, cfg *ProducerConfig) (*Producer[Request, Response], error) { @@ -91,37 +96,58 @@ func NewProducer[Request any, Response any](client redis.UniversalClient, stream }, nil } -func setMaxMsgIdInt(maxMsgIdInt *[2]uint64, msgId string) error { - idParts := strings.Split(msgId, "-") - if len(idParts) != 2 { - return fmt.Errorf("invalid i.d: %v", msgId) +// cmpMsgId compares two msgid's and returns (0) if equal, (-1) if msgId1 < msgId2, (1) if msgId1 > msgId2, (-2) if not comparable (or error) +func cmpMsgId(msgId1, msgId2 string) int { + getUintParts := func(msgId string) ([2]uint64, error) { + idParts := strings.Split(msgId, "-") + if len(idParts) != 2 { + return [2]uint64{}, fmt.Errorf("invalid i.d: %v", msgId) + } + idTimeStamp, err := strconv.ParseUint(idParts[0], 10, 64) + if err != nil { + return [2]uint64{}, fmt.Errorf("invalid i.d: %v err: %w", msgId, err) + } + idSerial, err := strconv.ParseUint(idParts[1], 10, 64) + if err != nil { + return [2]uint64{}, fmt.Errorf("invalid i.d serial: %v err: %w", msgId, err) + } + return [2]uint64{idTimeStamp, idSerial}, nil } - idTimeStamp, err := strconv.ParseUint(idParts[0], 10, 64) + id1, err := getUintParts(msgId1) if err != nil { - return fmt.Errorf("invalid i.d: %v err: %w", msgId, err) - } - if idTimeStamp < maxMsgIdInt[0] { - return nil + log.Trace("error comparing msgIds", "msgId1", msgId1, "msgId2", msgId2) + return -2 } - idSerial, err := strconv.ParseUint(idParts[1], 10, 64) + id2, err := getUintParts(msgId2) if err != nil { - return fmt.Errorf("invalid i.d serial: %v err: %w", msgId, err) + log.Trace("error comparing msgIds", "msgId1", msgId1, "msgId2", msgId2) + return -2 } - if idTimeStamp > maxMsgIdInt[0] { - maxMsgIdInt[0] = idTimeStamp - maxMsgIdInt[1] = idSerial - return nil + if id1[0] < id2[0] { + return -1 + } else if id1[0] > id2[0] { + return 1 + } else if id1[1] < id2[1] { + return -1 + } else if id1[1] > id2[1] { + return 1 } - // idTimeStamp == maxMsgIdInt[0] - if idSerial > maxMsgIdInt[1] { - maxMsgIdInt[1] = idSerial - } - return nil + return 0 } // checkResponses checks iteratively whether response for the promise is ready. func (p *Producer[Request, Response]) checkResponses(ctx context.Context) time.Duration { - maxMsgIdInt := [2]uint64{0, 0} + pelData, err := p.client.XPending(ctx, p.redisStream, p.redisGroup).Result() + if err != nil { + log.Error("error getting PEL data from xpending, xtrimming is disabled", "err", err) + } + deletePromise := func(id string) { + // Try deleting UNIQUEID_MSGID_MAP_KEY corresponding to this id from redis + if err := p.client.Del(ctx, MessageKeyFor(p.redisStream, id)+UNIQUEID_MSGID_MAP_KEY).Err(); err != nil { + log.Error("Error deleting key from redis that flags that a request is being processed", "err", err) + } + delete(p.promises, id) + } p.promisesLock.Lock() defer p.promisesLock.Unlock() responded := 0 @@ -135,16 +161,22 @@ func (p *Producer[Request, Response]) checkResponses(ctx context.Context) time.D if err != nil { if !errors.Is(err, redis.Nil) { log.Error("Error reading value in redis", "key", id, "error", err) + } else { + // The request this producer is waiting for has been past its TTL or is older than current PEL's lower, + // so safe to error and stop tracking this promise + allowedOldestID := fmt.Sprintf("%d-0", time.Now().Add(-p.cfg.RequestTimeout).UnixMilli()) + if pelData != nil && pelData.Lower != "" { + allowedOldestID = pelData.Lower + } + if cmpMsgId(msgIDAndPromise.msgID, allowedOldestID) == -1 { + msgIDAndPromise.promise.ProduceError(errors.New("error getting response, request has been waiting for too long")) + log.Error("error getting response, request has been waiting past its TTL") + errored++ + deletePromise(id) + } } continue } - // We keep track of a maxMsgId of a successfully solved request, because messages - // with id lower than this are either ack-ed or in PEL, so its safe to call XTRIMMINID on maxMsgId - errSetId := setMaxMsgIdInt(&maxMsgIdInt, msgIDAndPromise.msgID) - if errSetId != nil { - log.Error("error setting maxMsgId", "err", err) - return p.cfg.CheckResultInterval - } var resp Response if err := json.Unmarshal([]byte(res), &resp); err != nil { msgIDAndPromise.promise.ProduceError(fmt.Errorf("error unmarshalling: %w", err)) @@ -154,21 +186,36 @@ func (p *Producer[Request, Response]) checkResponses(ctx context.Context) time.D msgIDAndPromise.promise.Produce(resp) responded++ } - // Try deleting UNIQUEID_MSGID_MAP_KEY corresponding to this id from redis - if err := p.client.Del(ctx, msgKey+UNIQUEID_MSGID_MAP_KEY).Err(); err != nil { - log.Error("Error deleting key from redis that flags that a request is being processed", "err", err) - } - delete(p.promises, id) + deletePromise(id) } - var trimmed int64 - var trimErr error - maxMsgId := "+" - // If at least response for one promise was found, find the maximum of the found ones and XTRIMMINID from that msg id + 1 - if maxMsgIdInt[0] > 0 { - maxMsgId = fmt.Sprintf("%d-%d", maxMsgIdInt[0], maxMsgIdInt[1]+1) - trimmed, trimErr = p.client.XTrimMinID(ctx, p.redisStream, maxMsgId).Result() + // XDEL on consumer side already deletes acked messages (mark as deleted) but doesnt claim the memory back, XTRIM helps in claiming this memory in normal conditions + // pelData might be outdated when we do the xtrim, but thats ok as the messages are also being trimmed by other producers + if pelData != nil && pelData.Lower != "" { + trimmed, trimErr := p.client.XTrimMinID(ctx, p.redisStream, pelData.Lower).Result() + log.Trace("trimming", "xTrimMinID", pelData.Lower, "trimmed", trimmed, "responded", responded, "errored", errored, "trim-err", trimErr) + // Check if pelData.Lower has been past its TTL and if it is then ack it to remove from PEL and delete it, once + // its taken out from PEL the producer that sent this request will handle the corresponding promise accordingly (if PEL is non-empty) + allowedOldestID := fmt.Sprintf("%d-0", time.Now().Add(-p.cfg.RequestTimeout).UnixMilli()) + if cmpMsgId(pelData.Lower, allowedOldestID) == -1 { + if err := p.client.XClaim(ctx, &redis.XClaimArgs{ + Stream: p.redisStream, + Group: p.redisGroup, + Consumer: p.id, + MinIdle: 0, + Messages: []string{pelData.Lower}, + }).Err(); err != nil { + log.Error("error claiming PEL's lower message thats past its TTL", "msgID", pelData.Lower, "err", err) + return p.cfg.CheckResultInterval + } + if _, err := p.client.XAck(ctx, p.redisStream, p.redisGroup, pelData.Lower).Result(); err != nil { + log.Error("error acking PEL's lower message thats past its TTL", "msgID", pelData.Lower, "err", err) + return p.cfg.CheckResultInterval + } + if _, err := p.client.XDel(ctx, p.redisStream, pelData.Lower).Result(); err != nil { + log.Error("error deleting PEL's lower message thats past its TTL", "msgID", pelData.Lower, "err", err) + } + } } - log.Trace("trimming", "xTrimMinID", maxMsgId, "trimmed", trimmed, "responded", responded, "errored", errored, "trim-err", trimErr) return p.cfg.CheckResultInterval } diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index b1ffdca0fd..3883420f4e 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -26,8 +26,9 @@ var ( ) type testRequest struct { - Request string - SelfHash string // Is a unique identifier which can be used to compare any two validationInputs + Request string + IsInvalid bool + SelfHash string // Is a unique identifier which can be used to compare any two validationInputs } // SetSelfHash should be only called once. In the context of redis streams- by the producer @@ -63,6 +64,7 @@ func producerCfg() *ProducerConfig { return &ProducerConfig{ CheckResultInterval: TestProducerConfig.CheckResultInterval, ResponseEntryTimeout: TestProducerConfig.ResponseEntryTimeout, + RequestTimeout: TestProducerConfig.RequestTimeout, } } @@ -136,10 +138,13 @@ func flatten(responses [][]string) []string { return ret } -func produceMessages(ctx context.Context, msgs []string, producer *Producer[testRequest, testResponse], useUniqueIdentifier bool) ([]*containers.Promise[testResponse], error) { +func produceMessages(ctx context.Context, msgs []string, producer *Producer[testRequest, testResponse], useUniqueIdentifier, withInvalidEntries bool) ([]*containers.Promise[testResponse], error) { var promises []*containers.Promise[testResponse] for i := 0; i < len(msgs); i++ { req := testRequest{Request: msgs[i]} + if withInvalidEntries && i%50 == 0 { + req.IsInvalid = true + } if useUniqueIdentifier { req.SetSelfHash() } @@ -194,12 +199,14 @@ func consume(ctx context.Context, t *testing.T, consumers []*Consumer[testReques continue } gotMessages[idx][res.ID] = res.Value.Request - resp := fmt.Sprintf("result for: %v", res.ID) - if err := c.SetResult(ctx, res.Value.SelfHash, res.ID, testResponse{Response: resp}); err != nil { - t.Errorf("Error setting a result: %v", err) + if !res.Value.IsInvalid { + resp := fmt.Sprintf("result for: %v", res.ID) + if err := c.SetResult(ctx, res.Value.SelfHash, res.ID, testResponse{Response: resp}); err != nil { + t.Errorf("Error setting a result: %v", err) + } + wantResponses[idx] = append(wantResponses[idx], resp) } close(ackNotifier) - wantResponses[idx] = append(wantResponses[idx], resp) } }) } @@ -210,45 +217,50 @@ func TestRedisProduceComplex(t *testing.T) { log.SetDefault(log.NewLogger(log.NewTerminalHandlerWithLevel(os.Stderr, log.LevelTrace, true))) t.Parallel() for _, tc := range []struct { - name string - entries1Count int - entries2Count int - numProducers int - withDuplicates bool // If this is set, then every fourth entry (while generation) of each entries list is equal - killConsumers bool + name string + entriesCount []int + numProducers int + withDuplicates bool // If this is set, then every fourth entry (while generation) of each entries list is equal + killConsumers bool + withInvalidEntries bool // If this is set, then every 50th entry is invalid (requests that can't be solved by any consumer) }{ { - name: "one producer, all consumers are active", - entries1Count: messagesCount, - numProducers: 1, + name: "one producer, all consumers are active", + entriesCount: []int{messagesCount}, + numProducers: 1, }, { name: "one producer, some consumers killed, others should take over their work", - entries1Count: messagesCount, + entriesCount: []int{messagesCount}, numProducers: 1, killConsumers: true, }, { - name: "two producers, all consumers are active, all unique entries", - entries1Count: 20, - entries2Count: 20, - numProducers: 2, + name: "two producers, all consumers are active, all unique entries", + entriesCount: []int{20, 20}, + numProducers: 2, }, { name: "two producers, all consumers are active, some duplicate entries", - entries1Count: 20, - entries2Count: 20, + entriesCount: []int{20, 20}, numProducers: 2, withDuplicates: true, }, { name: "two producers, some consumers killed, others should take over their work, some duplicate entries, unequal number of requests from producers", - entries1Count: messagesCount, - entries2Count: 2 * messagesCount, + entriesCount: []int{messagesCount, 2 * messagesCount}, numProducers: 2, withDuplicates: true, killConsumers: true, }, + { + name: "two producers, some consumers killed, others should take over their work, some duplicate entries, some invalid entries, unequal number of requests from producers", + entriesCount: []int{messagesCount, 2 * messagesCount}, + numProducers: 2, + withDuplicates: true, + killConsumers: true, + withInvalidEntries: true, + }, } { t.Run(tc.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -271,15 +283,15 @@ func TestRedisProduceComplex(t *testing.T) { var entries [][]string if tc.numProducers == 2 { - entries = append(entries, wantMessages(tc.entries1Count, "1.", tc.withDuplicates)) - entries = append(entries, wantMessages(tc.entries2Count, "2.", tc.withDuplicates)) + entries = append(entries, wantMessages(tc.entriesCount[0], "1.", tc.withDuplicates)) + entries = append(entries, wantMessages(tc.entriesCount[1], "2.", tc.withDuplicates)) } else { - entries = append(entries, wantMessages(tc.entries1Count, "", tc.withDuplicates)) + entries = append(entries, wantMessages(tc.entriesCount[0], "", tc.withDuplicates)) } var promises [][]*containers.Promise[testResponse] for i := 0; i < tc.numProducers; i++ { - prs, err := produceMessages(ctx, entries[i], producers[i], tc.numProducers == 2) + prs, err := produceMessages(ctx, entries[i], producers[i], tc.numProducers == 2, tc.withInvalidEntries) if err != nil { t.Fatalf("Error producing messages from producer%d: %v", i, err) } @@ -311,8 +323,17 @@ func TestRedisProduceComplex(t *testing.T) { var gotResponses []string for i := 0; i < tc.numProducers; i++ { grs, errIndexes := awaitResponses(ctx, promises[i]) - if len(errIndexes) != 0 { - t.Fatalf("Error awaiting responses from promises%d: %v", i, errIndexes) + if tc.withInvalidEntries { + if errIndexes[len(errIndexes)-1]+50 <= len(entries[i]) { + t.Fatalf("Unexpected number of invalid requests while awaiting responses") + } + for j, idx := range errIndexes { + if idx != j*50 { + t.Fatalf("Invalid request' index mismatch want: %d got %d", j*50, idx) + } + } + } else if len(errIndexes) != 0 { + t.Fatalf("Error awaiting responses from promises %d: %v", i, errIndexes) } gotResponses = append(gotResponses, grs...) } @@ -325,6 +346,7 @@ func TestRedisProduceComplex(t *testing.T) { if err != nil { t.Fatalf("mergeMaps() unexpected error: %v", err) } + got = removeDuplicates(got) var combinedEntries []string for i := 0; i < tc.numProducers; i++ { @@ -384,14 +406,9 @@ func removeDuplicates(list []string) []string { // mergeValues merges maps from the slice and returns their values. // Returns and error if there exists duplicate key. func mergeValues(messages []map[string]string) ([]string, error) { - res := make(map[string]any) var ret []string for _, m := range messages { - for k, v := range m { - if _, found := res[k]; found { - return nil, fmt.Errorf("duplicate key: %v", k) - } - res[k] = v + for _, v := range m { ret = append(ret, v) } }