From 852b1f211ede282733ebdf62598f5b5cec1740fd Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Wed, 16 Oct 2024 12:24:40 +0800 Subject: [PATCH] add checks for bm25 k1 and b in sparse index checker Signed-off-by: Buqian Zheng --- pkg/util/indexparamcheck/constraints.go | 3 + .../sparse_float_vector_base_checker.go | 16 ++++ .../sparse_float_vector_base_checker_test.go | 89 +++++++++++++++++++ 3 files changed, 108 insertions(+) create mode 100644 pkg/util/indexparamcheck/sparse_float_vector_base_checker_test.go diff --git a/pkg/util/indexparamcheck/constraints.go b/pkg/util/indexparamcheck/constraints.go index 8be175fe22dd5..14d374e53c01a 100644 --- a/pkg/util/indexparamcheck/constraints.go +++ b/pkg/util/indexparamcheck/constraints.go @@ -45,6 +45,9 @@ const ( // Sparse Index Param SparseDropRatioBuild = "drop_ratio_build" + BM25K1 = "bm25_k1" + BM25B = "bm25_b" + MaxBitmapCardinalityLimit = 1000 ) diff --git a/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go b/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go index 218d2d3e03a3e..9cd8921e31f26 100644 --- a/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go +++ b/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go @@ -29,6 +29,22 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error } } + bm25K1Str, exist := params[BM25K1] + if exist { + bm25K1, err := strconv.ParseFloat(bm25K1Str, 64) + if err != nil || bm25K1 < 0 || bm25K1 > 3 { + return fmt.Errorf("invalid bm25_k1: %s, must be in range [0, 3]", bm25K1Str) + } + } + + bm25BStr, exist := params[BM25B] + if exist { + bm25B, err := strconv.ParseFloat(bm25BStr, 64) + if err != nil || bm25B < 0 || bm25B > 1 { + return fmt.Errorf("invalid bm25_b: %s, must be in range [0, 1]", bm25BStr) + } + } + return nil } diff --git a/pkg/util/indexparamcheck/sparse_float_vector_base_checker_test.go b/pkg/util/indexparamcheck/sparse_float_vector_base_checker_test.go new file mode 100644 index 0000000000000..2fb558f4fbcdf --- /dev/null +++ b/pkg/util/indexparamcheck/sparse_float_vector_base_checker_test.go @@ -0,0 +1,89 @@ +package indexparamcheck + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Test_sparseFloatVectorBaseChecker_StaticCheck(t *testing.T) { + validParams := map[string]string{ + Metric: "IP", + } + + invalidParams := map[string]string{ + Metric: "L2", + } + + c := newSparseFloatVectorBaseChecker() + + t.Run("valid metric", func(t *testing.T) { + err := c.StaticCheck(validParams) + assert.NoError(t, err) + }) + + t.Run("invalid metric", func(t *testing.T) { + err := c.StaticCheck(invalidParams) + assert.Error(t, err) + }) +} + +func Test_sparseFloatVectorBaseChecker_CheckTrain(t *testing.T) { + validParams := map[string]string{ + SparseDropRatioBuild: "0.5", + BM25K1: "1.5", + BM25B: "0.5", + } + + invalidDropRatio := map[string]string{ + SparseDropRatioBuild: "1.5", + } + + invalidBM25K1 := map[string]string{ + BM25K1: "3.5", + } + + invalidBM25B := map[string]string{ + BM25B: "1.5", + } + + c := newSparseFloatVectorBaseChecker() + + t.Run("valid params", func(t *testing.T) { + err := c.CheckTrain(validParams) + assert.NoError(t, err) + }) + + t.Run("invalid drop ratio", func(t *testing.T) { + err := c.CheckTrain(invalidDropRatio) + assert.Error(t, err) + }) + + t.Run("invalid BM25K1", func(t *testing.T) { + err := c.CheckTrain(invalidBM25K1) + assert.Error(t, err) + }) + + t.Run("invalid BM25B", func(t *testing.T) { + err := c.CheckTrain(invalidBM25B) + assert.Error(t, err) + }) +} + +func Test_sparseFloatVectorBaseChecker_CheckValidDataType(t *testing.T) { + c := newSparseFloatVectorBaseChecker() + + t.Run("valid data type", func(t *testing.T) { + field := &schemapb.FieldSchema{DataType: schemapb.DataType_SparseFloatVector} + err := c.CheckValidDataType(field) + assert.NoError(t, err) + }) + + t.Run("invalid data type", func(t *testing.T) { + field := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector} + err := c.CheckValidDataType(field) + assert.Error(t, err) + }) +}