Skip to content

Commit

Permalink
enhance: refactor createIndex in RESTful API
Browse files Browse the repository at this point in the history
Signed-off-by: lixinguo <[email protected]>
  • Loading branch information
lixinguo committed Oct 29, 2024
1 parent 8894346 commit 001a927
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 5 deletions.
12 changes: 10 additions & 2 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -1863,8 +1863,16 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any
}
c.Set(ContextRequest, req)

for key, value := range indexParam.Params {
req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
var err error
req.ExtraParams, err = convertToExtraParams(indexParam)
if err != nil {
// will not happen
log.Ctx(ctx).Warn("high level restful api, convertToExtraParams fail", zap.Error(err), zap.Any("request", anyReq))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: err.Error(),
})
return nil, err
}
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest))
Expand Down
40 changes: 40 additions & 0 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,46 @@ func TestDocInDocOutSearch(t *testing.T) {
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
}

func TestCreateIndex(t *testing.T) {
paramtable.Init()
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)

postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(IndexCategory, CreateAction)
// the previous format
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`),
})
// the current format
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`),
})

for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
fmt.Println(w.Body.String())
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message)
}
})
}
}

func TestCreateCollection(t *testing.T) {
paramtable.Init()
// disable rate limit
Expand Down
6 changes: 3 additions & 3 deletions tests/restful_client_v2/testcases/test_index_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_index_e2e(self, dim, metric_type, index_type):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['metricType'] == actual_index[i]['metricType']
assert expected_index[i]["params"]['index_type'] == actual_index[i]['indexType']
assert expected_index[i]['indexType'] == actual_index[i]['indexType']

# drop index
for i in range(len(actual_index)):
Expand Down Expand Up @@ -169,7 +169,7 @@ def test_index_for_scalar_field(self, dim, index_type):
for i in range(len(expected_index)):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType']
assert expected_index[i]['indexType'] == actual_index[i]['indexType']

@pytest.mark.parametrize("index_type", ["BIN_FLAT", "BIN_IVF_FLAT"])
@pytest.mark.parametrize("metric_type", ["JACCARD", "HAMMING"])
Expand Down Expand Up @@ -239,7 +239,7 @@ def test_index_for_binary_vector_field(self, dim, metric_type, index_type):
for i in range(len(expected_index)):
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType']
assert expected_index[i]['indexType'] == actual_index[i]['indexType']


@pytest.mark.L1
Expand Down

0 comments on commit 001a927

Please sign in to comment.