Skip to content

Commit

Permalink
fix: add checks for bm25 k1 and b in sparse index checker (#36907)
Browse files Browse the repository at this point in the history
issue: #36883,
#35853

Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian authored Oct 16, 2024
1 parent e5948bd commit 51f13ba
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pkg/util/indexparamcheck/constraints.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ const (
// Sparse Index Param
SparseDropRatioBuild = "drop_ratio_build"

BM25K1 = "bm25_k1"
BM25B = "bm25_b"

MaxBitmapCardinalityLimit = 1000
)

Expand Down
16 changes: 16 additions & 0 deletions pkg/util/indexparamcheck/sparse_float_vector_base_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error
}
}

bm25K1Str, exist := params[BM25K1]
if exist {
bm25K1, err := strconv.ParseFloat(bm25K1Str, 64)
if err != nil || bm25K1 < 0 || bm25K1 > 3 {
return fmt.Errorf("invalid bm25_k1: %s, must be in range [0, 3]", bm25K1Str)
}
}

bm25BStr, exist := params[BM25B]
if exist {
bm25B, err := strconv.ParseFloat(bm25BStr, 64)
if err != nil || bm25B < 0 || bm25B > 1 {
return fmt.Errorf("invalid bm25_b: %s, must be in range [0, 1]", bm25BStr)
}
}

return nil
}

Expand Down
89 changes: 89 additions & 0 deletions pkg/util/indexparamcheck/sparse_float_vector_base_checker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package indexparamcheck

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

func Test_sparseFloatVectorBaseChecker_StaticCheck(t *testing.T) {
validParams := map[string]string{
Metric: "IP",
}

invalidParams := map[string]string{
Metric: "L2",
}

c := newSparseFloatVectorBaseChecker()

t.Run("valid metric", func(t *testing.T) {
err := c.StaticCheck(validParams)
assert.NoError(t, err)
})

t.Run("invalid metric", func(t *testing.T) {
err := c.StaticCheck(invalidParams)
assert.Error(t, err)
})
}

func Test_sparseFloatVectorBaseChecker_CheckTrain(t *testing.T) {
validParams := map[string]string{
SparseDropRatioBuild: "0.5",
BM25K1: "1.5",
BM25B: "0.5",
}

invalidDropRatio := map[string]string{
SparseDropRatioBuild: "1.5",
}

invalidBM25K1 := map[string]string{
BM25K1: "3.5",
}

invalidBM25B := map[string]string{
BM25B: "1.5",
}

c := newSparseFloatVectorBaseChecker()

t.Run("valid params", func(t *testing.T) {
err := c.CheckTrain(validParams)
assert.NoError(t, err)
})

t.Run("invalid drop ratio", func(t *testing.T) {
err := c.CheckTrain(invalidDropRatio)
assert.Error(t, err)
})

t.Run("invalid BM25K1", func(t *testing.T) {
err := c.CheckTrain(invalidBM25K1)
assert.Error(t, err)
})

t.Run("invalid BM25B", func(t *testing.T) {
err := c.CheckTrain(invalidBM25B)
assert.Error(t, err)
})
}

// Test_sparseFloatVectorBaseChecker_CheckValidDataType checks that
// CheckValidDataType returns no error for a SparseFloatVector field and
// returns an error for a FloatVector field.
func Test_sparseFloatVectorBaseChecker_CheckValidDataType(t *testing.T) {
	checker := newSparseFloatVectorBaseChecker()

	cases := []struct {
		name     string
		dataType schemapb.DataType
		wantErr  bool
	}{
		{name: "valid data type", dataType: schemapb.DataType_SparseFloatVector},
		{name: "invalid data type", dataType: schemapb.DataType_FloatVector, wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := checker.CheckValidDataType(&schemapb.FieldSchema{DataType: tc.dataType})
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}

0 comments on commit 51f13ba

Please sign in to comment.