diff --git a/source/extensions/filters/http/rate_limit_quota/client_impl.cc b/source/extensions/filters/http/rate_limit_quota/client_impl.cc index 876060592eb3..e8f9c49a11e5 100644 --- a/source/extensions/filters/http/rate_limit_quota/client_impl.cc +++ b/source/extensions/filters/http/rate_limit_quota/client_impl.cc @@ -100,8 +100,18 @@ void RateLimitClientImpl::onReceiveMessage(RateLimitQuotaResponsePtr&& response) switch (action.bucket_action_case()) { case envoy::service::rate_limit_quota::v3::RateLimitQuotaResponse_BucketAction:: kQuotaAssignmentAction: { - quota_buckets_[bucket_id]->cached_action = action; + absl::optional cached_action = quota_buckets_[bucket_id]->cached_action; quota_buckets_[bucket_id]->current_assignment_time = time_source_.monotonicTime(); + + if (cached_action.has_value() && + Protobuf::util::MessageDifferencer::Equals(*cached_action, action)) { + ENVOY_LOG(debug, + "Cached action matches the incoming response so only TTL is updated for bucket " + "id: {}", + bucket_id); + break; + } + quota_buckets_[bucket_id]->cached_action = action; if (quota_buckets_[bucket_id]->cached_action->has_quota_assignment_action()) { auto rate_limit_strategy = quota_buckets_[bucket_id] ->cached_action->quota_assignment_action() diff --git a/test/extensions/filters/http/rate_limit_quota/client_test.cc b/test/extensions/filters/http/rate_limit_quota/client_test.cc index 336b47b7fff5..2765f7e53c54 100644 --- a/test/extensions/filters/http/rate_limit_quota/client_test.cc +++ b/test/extensions/filters/http/rate_limit_quota/client_test.cc @@ -17,6 +17,8 @@ class RateLimitClientTest : public testing::Test { RateLimitTestClient test_client{}; }; +using envoy::service::rate_limit_quota::v3::RateLimitQuotaResponse; + TEST_F(RateLimitClientTest, OpenAndCloseStream) { EXPECT_OK(test_client.client_->startStream(&test_client.stream_info_)); EXPECT_CALL(test_client.stream_, closeStream()); @@ -53,7 +55,7 @@ TEST_F(RateLimitClientTest, SendRequestAndReceiveResponse) { // `onQuotaResponse` callback is expected to be called. EXPECT_CALL(test_client.callbacks_, onQuotaResponse); - envoy::service::rate_limit_quota::v3::RateLimitQuotaResponse resp; + RateLimitQuotaResponse resp; auto response_buf = Grpc::Common::serializeMessage(resp); EXPECT_TRUE(test_client.stream_callbacks_->onReceiveMessageRaw(std::move(response_buf))); @@ -93,6 +95,92 @@ TEST_F(RateLimitClientTest, RestartStreamWhileInUse) { { test_client.client_->sendUsageReport(bucket_id_hash); }); } +TEST_F(RateLimitClientTest, HandlingDuplicateTokenBucketAssignments) { + EXPECT_OK(test_client.client_->startStream(&test_client.stream_info_)); + ASSERT_NE(test_client.stream_callbacks_, nullptr); + + auto empty_request_headers = Http::RequestHeaderMapImpl::create(); + test_client.stream_callbacks_->onCreateInitialMetadata(*empty_request_headers); + auto empty_response_headers = Http::ResponseHeaderMapImpl::create(); + test_client.stream_callbacks_->onReceiveInitialMetadata(std::move(empty_response_headers)); + + // `onQuotaResponse` callback is expected to be called twice. + EXPECT_CALL(test_client.callbacks_, onQuotaResponse).Times(3); + + ::envoy::type::v3::TokenBucket token_bucket; + token_bucket.set_max_tokens(100); + token_bucket.mutable_tokens_per_fill()->set_value(10); + token_bucket.mutable_fill_interval()->set_seconds(1000); + + ::envoy::service::rate_limit_quota::v3::BucketId bucket_id; + bucket_id.mutable_bucket()->insert({"fairshare_group_id", "mock_group"}); + const size_t bucket_id_hash = MessageUtil::hash(bucket_id); + + Bucket initial_bucket_state; + initial_bucket_state.bucket_id = bucket_id; + test_client.bucket_cache_.insert( + {bucket_id_hash, std::make_unique(std::move(initial_bucket_state))}); + + RateLimitQuotaResponse::BucketAction action; + action.mutable_quota_assignment_action() + ->mutable_rate_limit_strategy() + ->mutable_token_bucket() + ->MergeFrom(token_bucket); + action.mutable_bucket_id()->MergeFrom(bucket_id); + + RateLimitQuotaResponse resp; + resp.add_bucket_action()->MergeFrom(action); + RateLimitQuotaResponse duplicate_resp; + duplicate_resp.add_bucket_action()->MergeFrom(action); + + auto response_buf = Grpc::Common::serializeMessage(resp); + auto duplicate_response_buf = Grpc::Common::serializeMessage(duplicate_resp); + EXPECT_TRUE(test_client.stream_callbacks_->onReceiveMessageRaw(std::move(response_buf))); + + ASSERT_EQ(test_client.bucket_cache_.size(), 1); + ASSERT_TRUE(test_client.bucket_cache_.contains(bucket_id_hash)); + Bucket* first_bucket = test_client.bucket_cache_.at(bucket_id_hash).get(); + TokenBucket* first_token_bucket_limiter = first_bucket->token_bucket_limiter.get(); + EXPECT_TRUE(first_token_bucket_limiter); + + // Send a duplicate response & expect the token bucket to be carried forward + // in the cache to avoid resetting token consumption. + EXPECT_TRUE( + test_client.stream_callbacks_->onReceiveMessageRaw(std::move(duplicate_response_buf))); + + ASSERT_EQ(test_client.bucket_cache_.size(), 1); + ASSERT_TRUE(test_client.bucket_cache_.contains(bucket_id_hash)); + Bucket* second_bucket = test_client.bucket_cache_.at(bucket_id_hash).get(); + TokenBucket* second_token_bucket_limiter = second_bucket->token_bucket_limiter.get(); + EXPECT_TRUE(second_token_bucket_limiter); + EXPECT_EQ(first_token_bucket_limiter, second_token_bucket_limiter); + + // Expect the limiter to be replaced if the config changes. + resp.mutable_bucket_action(0) + ->mutable_quota_assignment_action() + ->mutable_rate_limit_strategy() + ->mutable_token_bucket() + ->set_max_tokens(200); + auto different_response_buf = Grpc::Common::serializeMessage(resp); + EXPECT_TRUE( + test_client.stream_callbacks_->onReceiveMessageRaw(std::move(different_response_buf))); + + ASSERT_EQ(test_client.bucket_cache_.size(), 1); + ASSERT_TRUE(test_client.bucket_cache_.contains(bucket_id_hash)); + Bucket* third_bucket = test_client.bucket_cache_.at(bucket_id_hash).get(); + TokenBucket* third_token_bucket_limiter = third_bucket->token_bucket_limiter.get(); + EXPECT_TRUE(third_token_bucket_limiter); + EXPECT_NE(first_token_bucket_limiter, third_token_bucket_limiter); + + auto empty_response_trailers = Http::ResponseTrailerMapImpl::create(); + test_client.stream_callbacks_->onReceiveTrailingMetadata(std::move(empty_response_trailers)); + + EXPECT_CALL(test_client.stream_, closeStream()); + EXPECT_CALL(test_client.stream_, resetStream()); + test_client.client_->closeStream(); + test_client.client_->onRemoteClose(0, ""); +} + } // namespace } // namespace RateLimitQuota } // namespace HttpFilters