From d493bc55f81374f8431f9fdd5fe315d7bff98e16 Mon Sep 17 00:00:00 2001 From: Petr Pchelko Date: Thu, 2 Jul 2020 06:36:55 -0700 Subject: [PATCH] Followups to v3 upgrade - Regenerate mocks based on new default protocol - Manually transform v2 messages to v3 messages - some of the fields were renamed thus json Marshal/Unmarshal does not work anymore - Added tests that verify conversion v2<->v3 works for headers fields - Update tests to use proto.Equal - simple assert.Equals might not work correctly for protobuf messages. Signed-off-by: Petr Pchelko --- src/service/ratelimit_legacy.go | 94 ++++++++++++++++++++------- test/common/common.go | 8 +++ test/integration/integration_test.go | 15 +++-- test/mocks/config/config.go | 4 +- test/mocks/limiter/limiter.go | 6 +- test/mocks/mocks.go | 2 +- test/mocks/rls/rls.go | 12 ++-- test/server/server_impl_test.go | 31 ++++----- test/service/ratelimit_legacy_test.go | 69 +++++++++++++------- test/service/ratelimit_test.go | 9 ++- 10 files changed, 166 insertions(+), 84 deletions(-) diff --git a/src/service/ratelimit_legacy.go b/src/service/ratelimit_legacy.go index 63885068..e8ecb98a 100644 --- a/src/service/ratelimit_legacy.go +++ b/src/service/ratelimit_legacy.go @@ -1,9 +1,10 @@ package ratelimit import ( + core_legacy "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + pb_struct "github.com/envoyproxy/go-control-plane/envoy/extensions/common/ratelimit/v3" pb_legacy "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/golang/protobuf/jsonpb" "github.com/lyft/gostats" "golang.org/x/net/context" ) @@ -62,20 +63,32 @@ func ConvertLegacyRequest(legacyRequest *pb_legacy.RateLimitRequest) (*pb.RateLi if legacyRequest == nil { return nil, nil } - - m := &jsonpb.Marshaler{} - s, err := m.MarshalToString(legacyRequest) - if err != nil { - return nil, err + request := &pb.RateLimitRequest{ + Domain: legacyRequest.GetDomain(), + HitsAddend: legacyRequest.GetHitsAddend(), } - - req := &pb.RateLimitRequest{} - err = jsonpb.UnmarshalString(s, req) - if err != nil { - return nil, err + if legacyRequest.GetDescriptors() != nil { + descriptors := make([]*pb_struct.RateLimitDescriptor, len(legacyRequest.GetDescriptors())) + for i, descriptor := range legacyRequest.GetDescriptors() { + if descriptor != nil { + descriptors[i] = &pb_struct.RateLimitDescriptor{} + if descriptor.GetEntries() != nil { + entries := make([]*pb_struct.RateLimitDescriptor_Entry, len(descriptor.GetEntries())) + for j, entry := range descriptor.GetEntries() { + if entry != nil { + entries[j] = &pb_struct.RateLimitDescriptor_Entry{ + Key: entry.GetKey(), + Value: entry.GetValue(), + } + } + } + descriptors[i].Entries = entries + } + } + } + request.Descriptors = descriptors } - - return req, nil + return request, nil } func ConvertResponse(response *pb.RateLimitResponse) (*pb_legacy.RateLimitResponse, error) { @@ -83,17 +96,54 @@ func ConvertResponse(response *pb.RateLimitResponse) (*pb_legacy.RateLimitRespon return nil, nil } - m := &jsonpb.Marshaler{} - s, err := m.MarshalToString(response) - if err != nil { - return nil, err + legacyResponse := &pb_legacy.RateLimitResponse{ + OverallCode: pb_legacy.RateLimitResponse_Code(response.GetOverallCode()), } - resp := &pb_legacy.RateLimitResponse{} - err = jsonpb.UnmarshalString(s, resp) - if err != nil { - return nil, err + if response.GetStatuses() != nil { + statuses := make([]*pb_legacy.RateLimitResponse_DescriptorStatus, len(response.GetStatuses())) + for i, status := range response.GetStatuses() { + if status != nil { + statuses[i] = &pb_legacy.RateLimitResponse_DescriptorStatus{ + Code: pb_legacy.RateLimitResponse_Code(status.GetCode()), + LimitRemaining: status.GetLimitRemaining(), + } + if status.GetCurrentLimit() != nil { + statuses[i].CurrentLimit = &pb_legacy.RateLimitResponse_RateLimit{ + RequestsPerUnit: status.GetCurrentLimit().GetRequestsPerUnit(), + Unit: pb_legacy.RateLimitResponse_RateLimit_Unit(status.GetCurrentLimit().GetUnit()), + } + } + } + } + legacyResponse.Statuses = statuses + } + + if response.GetRequestHeadersToAdd() != nil { + requestHeadersToAdd := make([]*core_legacy.HeaderValue, len(response.GetRequestHeadersToAdd())) + for i, header := range response.GetRequestHeadersToAdd() { + if header != nil { + requestHeadersToAdd[i] = &core_legacy.HeaderValue{ + Key: header.GetKey(), + Value: header.GetValue(), + } + } + } + legacyResponse.RequestHeadersToAdd = requestHeadersToAdd + } + + if response.GetResponseHeadersToAdd() != nil { + responseHeadersToAdd := make([]*core_legacy.HeaderValue, len(response.GetResponseHeadersToAdd())) + for i, header := range response.GetResponseHeadersToAdd() { + if header != nil { + responseHeadersToAdd[i] = &core_legacy.HeaderValue{ + Key: header.GetKey(), + Value: header.GetValue(), + } + } + } + legacyResponse.Headers = responseHeadersToAdd } - return resp, nil + return legacyResponse, nil } diff --git a/test/common/common.go b/test/common/common.go index 630161b4..b15c41d5 100644 --- a/test/common/common.go +++ b/test/common/common.go @@ -1,6 +1,9 @@ package common import ( + "fmt" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" "sync" pb_struct_legacy "github.com/envoyproxy/go-control-plane/envoy/api/v2/ratelimit" @@ -69,3 +72,8 @@ func NewRateLimitRequestLegacy(domain string, descriptors [][][2]string, hitsAdd request.HitsAddend = hitsAddend return request } + +func AssertProtoEqual(assert *assert.Assertions, expected proto.Message, actual proto.Message) { + assert.True(proto.Equal(expected, actual), + fmt.Sprintf("These two protobuf messages are not equal:\nexpected: %v\nactual: %v", expected, actual)) +} diff --git a/test/integration/integration_test.go b/test/integration/integration_test.go index 22582e2c..0f01ac98 100644 --- a/test/integration/integration_test.go +++ b/test/integration/integration_test.go @@ -166,7 +166,8 @@ func testBasicBaseConfig(grpcPort, perSecond string, local_cache_size string) fu response, err := c.ShouldRateLimit( context.Background(), common.NewRateLimitRequest("foo", [][][2]string{{{getCacheKey("hello", enable_local_cache), "world"}}}, 1)) - assert.Equal( + common.AssertProtoEqual( + assert, &pb.RateLimitResponse{ OverallCode: pb.RateLimitResponse_OK, Statuses: []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}}}, @@ -184,7 +185,8 @@ func testBasicBaseConfig(grpcPort, perSecond string, local_cache_size string) fu response, err = c.ShouldRateLimit( context.Background(), common.NewRateLimitRequest("basic", [][][2]string{{{getCacheKey("key1", enable_local_cache), "foo"}}}, 1)) - assert.Equal( + common.AssertProtoEqual( + assert, &pb.RateLimitResponse{ OverallCode: pb.RateLimitResponse_OK, Statuses: []*pb.RateLimitResponse_DescriptorStatus{ @@ -224,7 +226,8 @@ func testBasicBaseConfig(grpcPort, perSecond string, local_cache_size string) fu limitRemaining = 0 } - assert.Equal( + common.AssertProtoEqual( + assert, &pb.RateLimitResponse{ OverallCode: status, Statuses: []*pb.RateLimitResponse_DescriptorStatus{ @@ -287,7 +290,8 @@ func testBasicBaseConfig(grpcPort, perSecond string, local_cache_size string) fu limitRemaining2 = 0 } - assert.Equal( + common.AssertProtoEqual( + assert, &pb.RateLimitResponse{ OverallCode: status, Statuses: []*pb.RateLimitResponse_DescriptorStatus{ @@ -384,7 +388,8 @@ func TestBasicConfigLegacy(t *testing.T) { response, err := c.ShouldRateLimit( context.Background(), common.NewRateLimitRequestLegacy("foo", [][][2]string{{{"hello", "world"}}}, 1)) - assert.Equal( + common.AssertProtoEqual( + assert, &pb_legacy.RateLimitResponse{ OverallCode: pb_legacy.RateLimitResponse_OK, Statuses: []*pb_legacy.RateLimitResponse_DescriptorStatus{{Code: pb_legacy.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}}}, diff --git a/test/mocks/config/config.go b/test/mocks/config/config.go index 6205f7c9..38d5b347 100644 --- a/test/mocks/config/config.go +++ b/test/mocks/config/config.go @@ -6,7 +6,7 @@ package mock_config import ( context "context" - ratelimit "github.com/envoyproxy/go-control-plane/envoy/extensions/common/ratelimit/v3" + envoy_extensions_common_ratelimit_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/common/ratelimit/v3" config "github.com/envoyproxy/ratelimit/src/config" gomock "github.com/golang/mock/gomock" stats "github.com/lyft/gostats" @@ -51,7 +51,7 @@ func (mr *MockRateLimitConfigMockRecorder) Dump() *gomock.Call { } // GetLimit mocks base method -func (m *MockRateLimitConfig) GetLimit(arg0 context.Context, arg1 string, arg2 *ratelimit.RateLimitDescriptor) *config.RateLimit { +func (m *MockRateLimitConfig) GetLimit(arg0 context.Context, arg1 string, arg2 *envoy_extensions_common_ratelimit_v3.RateLimitDescriptor) *config.RateLimit { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetLimit", arg0, arg1, arg2) ret0, _ := ret[0].(*config.RateLimit) diff --git a/test/mocks/limiter/limiter.go b/test/mocks/limiter/limiter.go index 53c06923..7e9f3e5b 100644 --- a/test/mocks/limiter/limiter.go +++ b/test/mocks/limiter/limiter.go @@ -6,7 +6,7 @@ package mock_limiter import ( context "context" - v2 "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + envoy_service_ratelimit_v3 "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" config "github.com/envoyproxy/ratelimit/src/config" gomock "github.com/golang/mock/gomock" reflect "reflect" @@ -36,10 +36,10 @@ func (m *MockRateLimitCache) EXPECT() *MockRateLimitCacheMockRecorder { } // DoLimit mocks base method -func (m *MockRateLimitCache) DoLimit(arg0 context.Context, arg1 *v2.RateLimitRequest, arg2 []*config.RateLimit) []*v2.RateLimitResponse_DescriptorStatus { +func (m *MockRateLimitCache) DoLimit(arg0 context.Context, arg1 *envoy_service_ratelimit_v3.RateLimitRequest, arg2 []*config.RateLimit) []*envoy_service_ratelimit_v3.RateLimitResponse_DescriptorStatus { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DoLimit", arg0, arg1, arg2) - ret0, _ := ret[0].([]*v2.RateLimitResponse_DescriptorStatus) + ret0, _ := ret[0].([]*envoy_service_ratelimit_v3.RateLimitResponse_DescriptorStatus) return ret0 } diff --git a/test/mocks/mocks.go b/test/mocks/mocks.go index 703865af..9f8b18ce 100644 --- a/test/mocks/mocks.go +++ b/test/mocks/mocks.go @@ -5,4 +5,4 @@ package mocks //go:generate go run github.com/golang/mock/mockgen -destination ./config/config.go github.com/envoyproxy/ratelimit/src/config RateLimitConfig,RateLimitConfigLoader //go:generate go run github.com/golang/mock/mockgen -destination ./redis/redis.go github.com/envoyproxy/ratelimit/src/redis Client //go:generate go run github.com/golang/mock/mockgen -destination ./limiter/limiter.go github.com/envoyproxy/ratelimit/src/limiter RateLimitCache,TimeSource,JitterRandSource -//go:generate go run github.com/golang/mock/mockgen -destination ./rls/rls.go github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2 RateLimitServiceServer +//go:generate go run github.com/golang/mock/mockgen -destination ./rls/rls.go github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3 RateLimitServiceServer diff --git a/test/mocks/rls/rls.go b/test/mocks/rls/rls.go index 6be8fda9..92d79b9a 100644 --- a/test/mocks/rls/rls.go +++ b/test/mocks/rls/rls.go @@ -1,12 +1,12 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2 (interfaces: RateLimitServiceServer) +// Source: github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3 (interfaces: RateLimitServiceServer) -// Package mock_v2 is a generated GoMock package. -package mock_v2 +// Package mock_v3 is a generated GoMock package. +package mock_v3 import ( context "context" - v2 "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + envoy_service_ratelimit_v3 "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" gomock "github.com/golang/mock/gomock" reflect "reflect" ) @@ -35,10 +35,10 @@ func (m *MockRateLimitServiceServer) EXPECT() *MockRateLimitServiceServerMockRec } // ShouldRateLimit mocks base method -func (m *MockRateLimitServiceServer) ShouldRateLimit(arg0 context.Context, arg1 *v2.RateLimitRequest) (*v2.RateLimitResponse, error) { +func (m *MockRateLimitServiceServer) ShouldRateLimit(arg0 context.Context, arg1 *envoy_service_ratelimit_v3.RateLimitRequest) (*envoy_service_ratelimit_v3.RateLimitResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ShouldRateLimit", arg0, arg1) - ret0, _ := ret[0].(*v2.RateLimitResponse) + ret0, _ := ret[0].(*envoy_service_ratelimit_v3.RateLimitResponse) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/test/server/server_impl_test.go b/test/server/server_impl_test.go index b5646521..8ee22161 100644 --- a/test/server/server_impl_test.go +++ b/test/server/server_impl_test.go @@ -2,6 +2,8 @@ package server_test import ( "fmt" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/mock" "io/ioutil" "net/http" "net/http/httptest" @@ -11,7 +13,7 @@ import ( pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" "github.com/envoyproxy/ratelimit/src/server" - mock_v2 "github.com/envoyproxy/ratelimit/test/mocks/rls" + mock_v3 "github.com/envoyproxy/ratelimit/test/mocks/rls" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) @@ -41,8 +43,13 @@ func TestJsonHandler(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - rls := mock_v2.NewMockRateLimitServiceServer(controller) + rls := mock_v3.NewMockRateLimitServiceServer(controller) handler := server.NewJsonHandler(rls) + requestMatcher := mock.MatchedBy(func(req *pb.RateLimitRequest) bool { + return proto.Equal(req, &pb.RateLimitRequest{ + Domain: "foo", + }) + }) // Missing request body assertHttpResponse(t, handler, "", 400, "text/plain; charset=utf-8", "EOF\n") @@ -51,35 +58,25 @@ func TestJsonHandler(t *testing.T) { assertHttpResponse(t, handler, "}", 400, "text/plain; charset=utf-8", "invalid character '}' looking for beginning of value\n") // Unknown response code - rls.EXPECT().ShouldRateLimit(nil, &pb.RateLimitRequest{ - Domain: "foo", - }).Return(&pb.RateLimitResponse{}, nil) + rls.EXPECT().ShouldRateLimit(nil, requestMatcher).Return(&pb.RateLimitResponse{}, nil) assertHttpResponse(t, handler, `{"domain": "foo"}`, 500, "application/json", "{}") // ratelimit service error - rls.EXPECT().ShouldRateLimit(nil, &pb.RateLimitRequest{ - Domain: "foo", - }).Return(nil, fmt.Errorf("some error")) + rls.EXPECT().ShouldRateLimit(nil, requestMatcher).Return(nil, fmt.Errorf("some error")) assertHttpResponse(t, handler, `{"domain": "foo"}`, 400, "text/plain; charset=utf-8", "some error\n") // json unmarshaling error - rls.EXPECT().ShouldRateLimit(nil, &pb.RateLimitRequest{ - Domain: "foo", - }).Return(nil, nil) + rls.EXPECT().ShouldRateLimit(nil, requestMatcher).Return(nil, nil) assertHttpResponse(t, handler, `{"domain": "foo"}`, 500, "text/plain; charset=utf-8", "error marshaling proto3 to json: Marshal called with nil\n") // successful request, not rate limited - rls.EXPECT().ShouldRateLimit(nil, &pb.RateLimitRequest{ - Domain: "foo", - }).Return(&pb.RateLimitResponse{ + rls.EXPECT().ShouldRateLimit(nil, requestMatcher).Return(&pb.RateLimitResponse{ OverallCode: pb.RateLimitResponse_OK, }, nil) assertHttpResponse(t, handler, `{"domain": "foo"}`, 200, "application/json", `{"overallCode":"OK"}`) // successful request, rate limited - rls.EXPECT().ShouldRateLimit(nil, &pb.RateLimitRequest{ - Domain: "foo", - }).Return(&pb.RateLimitResponse{ + rls.EXPECT().ShouldRateLimit(nil, requestMatcher).Return(&pb.RateLimitResponse{ OverallCode: pb.RateLimitResponse_OVER_LIMIT, }, nil) assertHttpResponse(t, handler, `{"domain": "foo"}`, 429, "application/json", `{"overallCode":"OVER_LIMIT"}`) diff --git a/test/service/ratelimit_legacy_test.go b/test/service/ratelimit_legacy_test.go index 5de22470..d8839b71 100644 --- a/test/service/ratelimit_legacy_test.go +++ b/test/service/ratelimit_legacy_test.go @@ -3,7 +3,9 @@ package ratelimit_test import ( "testing" + core_legacy "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" pb_struct_legacy "github.com/envoyproxy/go-control-plane/envoy/api/v2/ratelimit" + core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" pb_struct "github.com/envoyproxy/go-control-plane/envoy/extensions/common/ratelimit/v3" pb_legacy "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" @@ -12,7 +14,6 @@ import ( "github.com/envoyproxy/ratelimit/src/service" "github.com/envoyproxy/ratelimit/test/common" "github.com/golang/mock/gomock" - "github.com/golang/protobuf/jsonpb" "github.com/lyft/gostats" "github.com/stretchr/testify/assert" "golang.org/x/net/context" @@ -23,19 +24,10 @@ func convertRatelimit(ratelimit *pb.RateLimitResponse_RateLimit) (*pb_legacy.Rat return nil, nil } - m := &jsonpb.Marshaler{} - s, err := m.MarshalToString(ratelimit) - if err != nil { - return nil, err - } - - rl := &pb_legacy.RateLimitResponse_RateLimit{} - err = jsonpb.UnmarshalString(s, rl) - if err != nil { - return nil, err - } - - return rl, nil + return &pb_legacy.RateLimitResponse_RateLimit{ + RequestsPerUnit: ratelimit.GetRequestsPerUnit(), + Unit: pb_legacy.RateLimitResponse_RateLimit_Unit(ratelimit.GetUnit()), + }, nil } func convertRatelimits(ratelimits []*config.RateLimit) ([]*pb_legacy.RateLimitResponse_RateLimit, error) { @@ -75,7 +67,8 @@ func TestServiceLegacy(test *testing.T) { []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}}) response, err := service.GetLegacyService().ShouldRateLimit(nil, legacyRequest) - t.assert.Equal( + common.AssertProtoEqual( + t.assert, &pb_legacy.RateLimitResponse{ OverallCode: pb_legacy.RateLimitResponse_OK, Statuses: []*pb_legacy.RateLimitResponse_DescriptorStatus{{Code: pb_legacy.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}}}, @@ -112,7 +105,8 @@ func TestServiceLegacy(test *testing.T) { []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0}, {Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}}) response, err = service.GetLegacyService().ShouldRateLimit(nil, legacyRequest) - t.assert.Equal( + common.AssertProtoEqual( + t.assert, &pb_legacy.RateLimitResponse{ OverallCode: pb_legacy.RateLimitResponse_OVER_LIMIT, Statuses: []*pb_legacy.RateLimitResponse_DescriptorStatus{ @@ -147,7 +141,8 @@ func TestServiceLegacy(test *testing.T) { []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0}}) response, err = service.GetLegacyService().ShouldRateLimit(nil, legacyRequest) - t.assert.Equal( + common.AssertProtoEqual( + t.assert, &pb_legacy.RateLimitResponse{ OverallCode: pb_legacy.RateLimitResponse_OVER_LIMIT, Statuses: []*pb_legacy.RateLimitResponse_DescriptorStatus{ @@ -261,7 +256,7 @@ func TestConvertLegacyRequest(test *testing.T) { assert.FailNow(test, err.Error()) } - assert.Equal(test, expectedRequest, req) + common.AssertProtoEqual(assert.New(test), expectedRequest, req) } { @@ -282,7 +277,7 @@ func TestConvertLegacyRequest(test *testing.T) { assert.FailNow(test, err.Error()) } - assert.Equal(test, expectedRequest, req) + common.AssertProtoEqual(assert.New(test), expectedRequest, req) } { @@ -341,7 +336,7 @@ func TestConvertLegacyRequest(test *testing.T) { assert.FailNow(test, err.Error()) } - assert.Equal(test, expectedRequest, req) + common.AssertProtoEqual(assert.New(test), expectedRequest, req) } } @@ -371,9 +366,21 @@ func TestConvertResponse(test *testing.T) { }, } + requestHeadersToAdd := []*core.HeaderValue{{ + Key: "test_request", + Value: "test_request_value", + }, nil} + + responseHeadersToAdd := []*core.HeaderValue{{ + Key: "test_response", + Value: "test_response", + }, nil} + response := &pb.RateLimitResponse{ - OverallCode: pb.RateLimitResponse_OVER_LIMIT, - Statuses: statuses, + OverallCode: pb.RateLimitResponse_OVER_LIMIT, + Statuses: statuses, + RequestHeadersToAdd: requestHeadersToAdd, + ResponseHeadersToAdd: responseHeadersToAdd, } expectedRl := &pb_legacy.RateLimitResponse_RateLimit{ @@ -395,9 +402,21 @@ func TestConvertResponse(test *testing.T) { }, } + expectedRequestHeadersToAdd := []*core_legacy.HeaderValue{{ + Key: "test_request", + Value: "test_request_value", + }, nil} + + expecpectedResponseHeadersToAdd := []*core_legacy.HeaderValue{{ + Key: "test_response", + Value: "test_response", + }, nil} + expectedResponse := &pb_legacy.RateLimitResponse{ - OverallCode: pb_legacy.RateLimitResponse_OVER_LIMIT, - Statuses: expectedStatuses, + OverallCode: pb_legacy.RateLimitResponse_OVER_LIMIT, + Statuses: expectedStatuses, + RequestHeadersToAdd: expectedRequestHeadersToAdd, + Headers: expecpectedResponseHeadersToAdd, } resp, err = ratelimit.ConvertResponse(response) @@ -405,5 +424,5 @@ func TestConvertResponse(test *testing.T) { assert.FailNow(test, err.Error()) } - assert.Equal(test, expectedResponse, resp) + common.AssertProtoEqual(assert.New(test), expectedResponse, resp) } diff --git a/test/service/ratelimit_test.go b/test/service/ratelimit_test.go index 5c862fbe..12c77926 100644 --- a/test/service/ratelimit_test.go +++ b/test/service/ratelimit_test.go @@ -97,7 +97,8 @@ func TestService(test *testing.T) { []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}}) response, err := service.ShouldRateLimit(nil, request) - t.assert.Equal( + common.AssertProtoEqual( + t.assert, &pb.RateLimitResponse{ OverallCode: pb.RateLimitResponse_OK, Statuses: []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}}}, @@ -124,7 +125,8 @@ func TestService(test *testing.T) { []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[0].Limit, LimitRemaining: 0}, {Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}}) response, err = service.ShouldRateLimit(nil, request) - t.assert.Equal( + common.AssertProtoEqual( + t.assert, &pb.RateLimitResponse{ OverallCode: pb.RateLimitResponse_OVER_LIMIT, Statuses: []*pb.RateLimitResponse_DescriptorStatus{ @@ -154,7 +156,8 @@ func TestService(test *testing.T) { []*pb.RateLimitResponse_DescriptorStatus{{Code: pb.RateLimitResponse_OK, CurrentLimit: nil, LimitRemaining: 0}, {Code: pb.RateLimitResponse_OVER_LIMIT, CurrentLimit: limits[1].Limit, LimitRemaining: 0}}) response, err = service.ShouldRateLimit(nil, request) - t.assert.Equal( + common.AssertProtoEqual( + t.assert, &pb.RateLimitResponse{ OverallCode: pb.RateLimitResponse_OVER_LIMIT, Statuses: []*pb.RateLimitResponse_DescriptorStatus{