From 371fd4d86e42801fc3c57fb7abed2a898c69036e Mon Sep 17 00:00:00 2001 From: lixinguo Date: Tue, 29 Oct 2024 11:25:13 +0800 Subject: [PATCH] enhance: refactor createIndex in RESTful API Signed-off-by: lixinguo --- .../proxy/httpserver/handler_v2.go | 12 +++++- .../proxy/httpserver/handler_v2_test.go | 40 +++++++++++++++++++ .../distributed/proxy/httpserver/utils.go | 8 ++++ .../testcases/test_index_operation.py | 12 +++--- 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index fced1d4877ed1..4548e334d9bfc 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -1863,8 +1863,16 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any } c.Set(ContextRequest, req) - for key, value := range indexParam.Params { - req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) + var err error + req.ExtraParams, err = convertToExtraParams(indexParam) + if err != nil { + // will not happen + log.Ctx(ctx).Warn("high level restful api, convertToExtraParams fail", zap.Error(err), zap.Any("request", anyReq)) + HTTPAbortReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(err), + HTTPReturnMessage: err.Error(), + }) + return nil, err } resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest)) diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index 8f317c9739f3d..d8047bde8b323 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -709,6 +709,46 @@ func TestDocInDocOutSearch(t *testing.T) { sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase) } +func TestCreateIndex(t *testing.T) { + paramtable.Init() + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + + postTestCases := []requestBodyTestCase{} + mp := mocks.NewMockProxy(t) + mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() + testEngine := initHTTPServerV2(mp, false) + path := versionalV2(IndexCategory, CreateAction) + // the previous format + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`), + }) + // the current format + postTestCases = append(postTestCases, requestBodyTestCase{ + path: path, + requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`), + }) + + for _, testcase := range postTestCases { + t.Run("post"+testcase.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody)) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + fmt.Println(w.Body.String()) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err) + assert.Equal(t, testcase.errCode, returnBody.Code) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errMsg, returnBody.Message) + } + }) + } +} + func TestCreateCollection(t *testing.T) { paramtable.Init() // disable rate limit diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 084a579925e5f..eda05468a84e2 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -1489,6 +1489,14 @@ func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, erro if indexParam.IndexType != "" { params = append(params, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: indexParam.IndexType}) } + if indexParam.IndexType == "" { + for key, value := range indexParam.Params { + if key == common.IndexTypeKey { + params = append(params, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: fmt.Sprintf("%v", value)}) + break + } + } + } if indexParam.MetricType != "" { params = append(params, &commonpb.KeyValuePair{Key: common.MetricTypeKey, Value: indexParam.MetricType}) } diff --git a/tests/restful_client_v2/testcases/test_index_operation.py b/tests/restful_client_v2/testcases/test_index_operation.py index 534684c9bfbdf..a528aec87a18a 100644 --- a/tests/restful_client_v2/testcases/test_index_operation.py +++ b/tests/restful_client_v2/testcases/test_index_operation.py @@ -72,8 +72,10 @@ def test_index_e2e(self, dim, metric_type, index_type): "metricType": f"{metric_type}"}] } if index_type == "HNSW": + payload["indexParams"][0]["indexType"]="HNSW" payload["indexParams"][0]["params"] = {"index_type": "HNSW", "M": "16", "efConstruction": "200"} if index_type == "AUTOINDEX": + payload["indexParams"][0]["indexType"]="AUTOINDEX" payload["indexParams"][0]["params"] = {"index_type": "AUTOINDEX"} rsp = self.index_client.index_create(payload) assert rsp['code'] == 0 @@ -89,8 +91,8 @@ def test_index_e2e(self, dim, metric_type, index_type): for i in range(len(expected_index)): assert expected_index[i]['fieldName'] == actual_index[i]['fieldName'] assert expected_index[i]['indexName'] == actual_index[i]['indexName'] + assert expected_index[i]['indexType'] == actual_index[i]['indexType'] assert expected_index[i]['metricType'] == actual_index[i]['metricType'] - assert expected_index[i]["params"]['index_type'] == actual_index[i]['indexType'] # drop index for i in range(len(actual_index)): @@ -152,7 +154,7 @@ def test_index_for_scalar_field(self, dim, index_type): # create index payload = { "collectionName": name, - "indexParams": [{"fieldName": "word_count", "indexName": "word_count_vector", + "indexParams": [{"fieldName": "word_count", "indexName": "word_count_vector","indexType": "INVERTED", "params": {"index_type": "INVERTED"}}] } rsp = self.index_client.index_create(payload) @@ -169,7 +171,7 @@ def test_index_for_scalar_field(self, dim, index_type): for i in range(len(expected_index)): assert expected_index[i]['fieldName'] == actual_index[i]['fieldName'] assert expected_index[i]['indexName'] == actual_index[i]['indexName'] - assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType'] + assert expected_index[i]['indexType'] == actual_index[i]['indexType'] @pytest.mark.parametrize("index_type", ["BIN_FLAT", "BIN_IVF_FLAT"]) @pytest.mark.parametrize("metric_type", ["JACCARD", "HAMMING"]) @@ -220,7 +222,7 @@ def test_index_for_binary_vector_field(self, dim, metric_type, index_type): index_name = "binary_vector_index" payload = { "collectionName": name, - "indexParams": [{"fieldName": "binary_vector", "indexName": index_name, "metricType": metric_type, + "indexParams": [{"fieldName": "binary_vector", "indexName": index_name, "metricType": metric_type,"indexType": index_type, "params": {"index_type": index_type}}] } if index_type == "BIN_IVF_FLAT": @@ -239,7 +241,7 @@ def test_index_for_binary_vector_field(self, dim, metric_type, index_type): for i in range(len(expected_index)): assert expected_index[i]['fieldName'] == actual_index[i]['fieldName'] assert expected_index[i]['indexName'] == actual_index[i]['indexName'] - assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType'] + assert expected_index[i]['indexType'] == actual_index[i]['indexType'] @pytest.mark.L1