Skip to content

Commit

Permalink
enhance: refine group_strict_size parameter(#37482)
Browse files Browse the repository at this point in the history
Signed-off-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han committed Nov 7, 2024
1 parent 8275e40 commit e32153d
Show file tree
Hide file tree
Showing 10 changed files with 38 additions and 38 deletions.
2 changes: 1 addition & 1 deletion internal/core/src/common/QueryInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace milvus {
struct SearchInfo {
int64_t topk_{0};
int64_t group_size_{1};
bool group_strict_size_{false};
bool strict_group_size_{false};
int64_t round_decimal_{0};
FieldId field_id_;
MetricType metric_type_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int8_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
search_info.strict_group_size_,
*dataGetter,
group_by_values,
seg_offsets,
Expand All @@ -59,7 +59,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int16_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
search_info.strict_group_size_,
*dataGetter,
group_by_values,
seg_offsets,
Expand All @@ -74,7 +74,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int32_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
search_info.strict_group_size_,
*dataGetter,
group_by_values,
seg_offsets,
Expand All @@ -89,7 +89,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int64_t>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
search_info.strict_group_size_,
*dataGetter,
group_by_values,
seg_offsets,
Expand All @@ -103,7 +103,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<bool>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
search_info.strict_group_size_,
*dataGetter,
group_by_values,
seg_offsets,
Expand All @@ -118,7 +118,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<std::string>(iterators,
search_info.topk_,
search_info.group_size_,
search_info.group_strict_size_,
search_info.strict_group_size_,
*dataGetter,
group_by_values,
seg_offsets,
Expand Down
4 changes: 2 additions & 2 deletions internal/core/src/query/PlanProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
search_info.group_size_ = query_info_proto.group_size() > 0
? query_info_proto.group_size()
: 1;
search_info.group_strict_size_ =
query_info_proto.group_strict_size();
search_info.strict_group_size_ =
query_info_proto.strict_group_size();
}
return search_info;
};
Expand Down
2 changes: 1 addition & 1 deletion internal/distributed/proxy/httpserver/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,6 @@ const (
ParamRangeFilter = "range_filter"
ParamGroupByField = "group_by_field"
ParamGroupSize = "group_size"
ParamGroupStrictSize = "group_strict_size"
ParamStrictGroupSize = "strict_group_size"
BoundedTimestamp = 2
)
4 changes: 2 additions & 2 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
}
if httpReq.GroupByField != "" && httpReq.GroupSize > 0 {
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamStrictGroupSize, Value: strconv.FormatBool(httpReq.StrictGroupSize)})

Check warning on line 998 in internal/distributed/proxy/httpserver/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

internal/distributed/proxy/httpserver/handler_v2.go#L998

Added line #L998 was not covered by tests
}
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
body, _ := c.Get(gin.BodyBytesKey)
Expand Down Expand Up @@ -1100,7 +1100,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
}
if httpReq.GroupByField != "" && httpReq.GroupSize > 0 {
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)})
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamStrictGroupSize, Value: strconv.FormatBool(httpReq.StrictGroupSize)})

Check warning on line 1103 in internal/distributed/proxy/httpserver/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

internal/distributed/proxy/httpserver/handler_v2.go#L1103

Added line #L1103 was not covered by tests
}
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest))
Expand Down
4 changes: 2 additions & 2 deletions internal/distributed/proxy/httpserver/request_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ type SearchReqV2 struct {
Filter string `json:"filter"`
GroupByField string `json:"groupingField"`
GroupSize int32 `json:"groupSize"`
GroupStrictSize bool `json:"groupStrictSize"`
StrictGroupSize bool `json:"strictGroupSize"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"`
Expand Down Expand Up @@ -197,7 +197,7 @@ type HybridSearchReq struct {
Limit int32 `json:"limit"`
GroupByField string `json:"groupingField"`
GroupSize int32 `json:"groupSize"`
GroupStrictSize bool `json:"groupStrictSize"`
StrictGroupSize bool `json:"strictGroupSize"`
OutputFields []string `json:"outputFields"`
ConsistencyLevel string `json:"consistencyLevel"`
}
Expand Down
2 changes: 1 addition & 1 deletion internal/proto/plan.proto
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ message QueryInfo {
int64 group_by_field_id = 6;
bool materialized_view_involved = 7;
int64 group_size = 8;
bool group_strict_size = 9;
bool strict_group_size = 9;
double bm25_avgdl = 10;
int64 query_field_id =11;
}
Expand Down
34 changes: 17 additions & 17 deletions internal/proxy/search_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type rankParams struct {
roundDecimal int64
groupByFieldId int64
groupSize int64
groupStrictSize bool
strictGroupSize bool
}

func (r *rankParams) GetLimit() int64 {
Expand Down Expand Up @@ -64,9 +64,9 @@ func (r *rankParams) GetGroupSize() int64 {
return 1
}

func (r *rankParams) GetGroupStrictSize() bool {
func (r *rankParams) GetStrictGroupSize() bool {
if r != nil {
return r.groupStrictSize
return r.strictGroupSize
}
return false
}
Expand Down Expand Up @@ -170,15 +170,15 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb

// 5. parse group by field and group by size
var groupByFieldId, groupSize int64
var groupStrictSize bool
var strictGroupSize bool
if isAdvanced {
groupByFieldId, groupSize, groupStrictSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetGroupStrictSize()
groupByFieldId, groupSize, strictGroupSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetStrictGroupSize()
} else {
groupByInfo := parseGroupByInfo(searchParamsPair, schema)
if groupByInfo.err != nil {
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: groupByInfo.err}
}
groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize()
groupByFieldId, groupSize, strictGroupSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetStrictGroupSize()
}

// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
Expand All @@ -199,7 +199,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
GroupSize: groupSize,
GroupStrictSize: groupStrictSize,
StrictGroupSize: strictGroupSize,
},
offset: offset,
isIterator: isIterator,
Expand Down Expand Up @@ -303,7 +303,7 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string,
type groupByInfo struct {
groupByFieldId int64
groupSize int64
groupStrictSize bool
strictGroupSize bool
err error
}

Expand All @@ -321,9 +321,9 @@ func (g *groupByInfo) GetGroupSize() int64 {
return 0
}

func (g *groupByInfo) GetGroupStrictSize() bool {
func (g *groupByInfo) GetStrictGroupSize() bool {
if g != nil {
return g.groupStrictSize
return g.strictGroupSize
}
return false
}
Expand Down Expand Up @@ -389,17 +389,17 @@ func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemap
ret.groupSize = groupSize

// 3. parse group strict size
var groupStrictSize bool
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
var strictGroupSize bool
strictGroupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(StrictGroupSize, searchParamsPair)
if err != nil {
groupStrictSize = false
strictGroupSize = false
} else {
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
strictGroupSize, err = strconv.ParseBool(strictGroupSizeStr)
if err != nil {
groupStrictSize = false
strictGroupSize = false

Check warning on line 399 in internal/proxy/search_util.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/search_util.go#L399

Added line #L399 was not covered by tests
}
}
ret.groupStrictSize = groupStrictSize
ret.strictGroupSize = strictGroupSize
return ret
}

Expand Down Expand Up @@ -460,7 +460,7 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.C
roundDecimal: roundDecimal,
groupByFieldId: groupByInfo.GetGroupByFieldId(),
groupSize: groupByInfo.GetGroupSize(),
groupStrictSize: groupByInfo.GetGroupStrictSize(),
strictGroupSize: groupByInfo.GetStrictGroupSize(),
}, nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ const (
IteratorField = "iterator"
GroupByFieldKey = "group_by_field"
GroupSizeKey = "group_size"
GroupStrictSize = "group_strict_size"
StrictGroupSize = "strict_group_size"
RankGroupScorer = "rank_group_scorer"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
Expand Down
10 changes: 5 additions & 5 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2325,7 +2325,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
},
}
// 1. first parse rank params
// outer params require to group by field 101 and groupSize=3 and groupStrictSize=false
// outer params require to group by field 101 and groupSize=3 and strictGroupSize=false
testRankParamsPairs := getValidSearchParams()
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Expand All @@ -2336,7 +2336,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
Value: strconv.FormatInt(3, 10),
})
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
Key: GroupStrictSize,
Key: StrictGroupSize,
Value: "false",
})
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
Expand All @@ -2348,7 +2348,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

// 2. parse search params for sub request in hybridsearch
params := getValidSearchParams()
// inner params require to group by field 103 and groupSize=10 and groupStrictSize=true
// inner params require to group by field 103 and groupSize=10 and strictGroupSize=true
params = append(params, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "c3",
Expand All @@ -2358,7 +2358,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
Value: strconv.FormatInt(10, 10),
})
params = append(params, &commonpb.KeyValuePair{
Key: GroupStrictSize,
Key: StrictGroupSize,
Value: "true",
})

Expand All @@ -2370,7 +2370,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
// set by main request rather than inner sub request
assert.Equal(t, int64(101), searchInfo.planInfo.GetGroupByFieldId())
assert.Equal(t, int64(3), searchInfo.planInfo.GetGroupSize())
assert.False(t, searchInfo.planInfo.GetGroupStrictSize())
assert.False(t, searchInfo.planInfo.GetStrictGroupSize())
})

t.Run("parseSearchInfo error", func(t *testing.T) {
Expand Down

0 comments on commit e32153d

Please sign in to comment.