fix: Unify loaded partition check to delegator #36879

Merged: 2 commits, Oct 15, 2024
65 changes: 44 additions & 21 deletions internal/querynodev2/delegator/delegator.go
@@ -228,6 +228,25 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu
return nodeReq
}

func (sd *shardDelegator) getTargetPartitions(reqPartitions []int64) (searchPartitions []int64, err error) {
existPartitions := sd.collection.GetPartitions()

// search all loaded partitions if req partition ids not provided
if len(reqPartitions) == 0 {
searchPartitions = existPartitions
return searchPartitions, nil
}

// use brute-force slice search to avoid the cost of building a map
for _, partition := range reqPartitions {
if !funcutil.SliceContain(existPartitions, partition) {
return nil, merr.WrapErrPartitionNotLoaded(reqPartitions)
}
}
searchPartitions = reqPartitions
return searchPartitions, nil
}

// Search performs search operation on shard.
func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest, sealed []SnapshotItem, growing []SegmentEntry) ([]*internalpb.SearchResults, error) {
log := sd.getLogger(ctx)
@@ -302,11 +321,6 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}

partitions := req.GetReq().GetPartitionIDs()
if !sd.collection.ExistPartition(partitions...) {
return nil, merr.WrapErrPartitionNotLoaded(partitions)
}

// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
@@ -327,9 +341,14 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable")
}
defer sd.distribution.Unpin(version)
existPartitions := sd.collection.GetPartitions()
targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
if err != nil {
return nil, err
}
// set target partition IDs on the sub-task request
req.Req.PartitionIDs = targetPartitions
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(existPartitions, segment.PartitionID)
return funcutil.SliceContain(targetPartitions, segment.PartitionID)
})

if req.GetReq().GetIsAdvanced() {
@@ -418,11 +437,6 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
return fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}

partitions := req.GetReq().GetPartitionIDs()
if !sd.collection.ExistPartition(partitions...) {
return merr.WrapErrPartitionNotLoaded(partitions)
}

// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
@@ -443,9 +457,16 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
return merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable")
}
defer sd.distribution.Unpin(version)
existPartitions := sd.collection.GetPartitions()

targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
if err != nil {
return err
}
// set target partition IDs on the sub-task request
req.Req.PartitionIDs = targetPartitions

growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(existPartitions, segment.PartitionID)
return funcutil.SliceContain(targetPartitions, segment.PartitionID)
})
if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
@@ -489,11 +510,6 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}

partitions := req.GetReq().GetPartitionIDs()
if !sd.collection.ExistPartition(partitions...) {
return nil, merr.WrapErrPartitionNotLoaded(partitions)
}

// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
@@ -514,12 +530,19 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable")
}
defer sd.distribution.Unpin(version)

targetPartitions, err := sd.getTargetPartitions(req.GetReq().GetPartitionIDs())
if err != nil {
return nil, err
}
// set target partition IDs on the sub-task request
req.Req.PartitionIDs = targetPartitions

if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
} else {
existPartitions := sd.collection.GetPartitions()
growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool {
return funcutil.SliceContain(existPartitions, segment.PartitionID)
return funcutil.SliceContain(targetPartitions, segment.PartitionID)
})
}

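For context, the behavior the new getTargetPartitions helper centralizes can be summarized with a small, self-contained sketch (the names and types below are simplified stand-ins rather than the actual Milvus API): an empty request falls back to every loaded partition, and a request naming any partition that is not loaded fails as a whole.

```go
package main

import (
	"errors"
	"fmt"
	"slices"
)

// errPartitionNotLoaded stands in for merr.WrapErrPartitionNotLoaded.
var errPartitionNotLoaded = errors.New("partition not loaded")

// getTargetPartitions mirrors the delegator helper: with no requested
// partitions it returns every loaded partition; otherwise each requested
// partition must be in the loaded set, and the request is echoed back.
func getTargetPartitions(loaded, requested []int64) ([]int64, error) {
	if len(requested) == 0 {
		return loaded, nil
	}
	for _, p := range requested {
		if !slices.Contains(loaded, p) {
			return nil, fmt.Errorf("%w: %v", errPartitionNotLoaded, requested)
		}
	}
	return requested, nil
}

func main() {
	loaded := []int64{100, 101}

	all, _ := getTargetPartitions(loaded, nil)
	fmt.Println(all) // [100 101]: empty request searches every loaded partition

	_, err := getTargetPartitions(loaded, []int64{100, 999})
	fmt.Println(err) // partition not loaded: [100 999]
}
```

Once resolved, the delegator writes the target IDs back onto the sub-request (req.Req.PartitionIDs = targetPartitions), so Search, Query, and QueryStream all filter growing segments against the same list instead of each re-checking collection metadata on its own.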
8 changes: 4 additions & 4 deletions internal/querynodev2/segments/retrieve.go
@@ -156,10 +156,10 @@ func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *qu

if req.GetScope() == querypb.DataScope_Historical {
SegType = SegmentTypeSealed
retrieveSegments, err = validateOnHistorical(ctx, manager, collID, nil, segIDs)
retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
} else {
SegType = SegmentTypeGrowing
retrieveSegments, err = validateOnStream(ctx, manager, collID, nil, segIDs)
retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
}

if err != nil {
@@ -181,10 +181,10 @@ func RetrieveStream(ctx context.Context, manager *Manager, plan *RetrievePlan, r

if req.GetScope() == querypb.DataScope_Historical {
SegType = SegmentTypeSealed
retrieveSegments, err = validateOnHistorical(ctx, manager, collID, nil, segIDs)
retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
} else {
SegType = SegmentTypeGrowing
retrieveSegments, err = validateOnStream(ctx, manager, collID, nil, segIDs)
retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs)
}

if err != nil {
9 changes: 4 additions & 5 deletions internal/querynodev2/segments/validate.go
@@ -42,17 +42,16 @@
if len(partitionIDs) == 0 {
searchPartIDs = collection.GetPartitions()
} else {
if collection.ExistPartition(partitionIDs...) {
searchPartIDs = partitionIDs
}
// use request partition ids directly, ignoring meta partition ids
// partitions shall be controlled by delegator distribution
searchPartIDs = partitionIDs
}

log.Ctx(ctx).Debug("read target partitions", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", searchPartIDs))

// all partitions have been released
if len(searchPartIDs) == 0 && collection.GetLoadType() == querypb.LoadType_LoadPartition {
return nil, errors.New("partitions have been released , collectionID = " +
fmt.Sprintln(collectionID) + "target partitionIDs = " + fmt.Sprintln(searchPartIDs))
return nil, errors.Newf("partitions have been released , collectionID = %d target partitionIDs = %v", collectionID, searchPartIDs)

Codecov / codecov/patch: added line internal/querynodev2/segments/validate.go#L54 was not covered by tests
}

if len(searchPartIDs) == 0 && collection.GetLoadType() == querypb.LoadType_LoadCollection {
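Read together with the retrieve.go change above and the search_task.go change below, the worker-side validation no longer second-guesses the delegator: the request partition IDs, already resolved upstream, are used directly, and only the fallbacks survive. A rough, self-contained sketch of that remaining logic, with simplified names and without the Manager/segment checks of the real code:

```go
package main

import "fmt"

// resolveSearchPartitions is a simplified sketch of the remaining validate.go
// fallback: an empty request still means "all partitions of the collection",
// and an empty result while the collection was loaded partition-by-partition
// is reported as "partitions have been released".
func resolveSearchPartitions(requested, collectionPartitions []int64, loadedByPartition bool) ([]int64, error) {
	searchPartIDs := requested
	if len(requested) == 0 {
		searchPartIDs = collectionPartitions
	}
	if len(searchPartIDs) == 0 && loadedByPartition {
		return nil, fmt.Errorf("partitions have been released, target partitionIDs = %v", searchPartIDs)
	}
	return searchPartIDs, nil
}

func main() {
	// Typical case after this PR: the delegator has already resolved and
	// validated the IDs, so the worker simply uses them.
	ids, _ := resolveSearchPartitions([]int64{100}, []int64{100, 101}, true)
	fmt.Println(ids) // [100]

	// Degenerate case kept from the old code: nothing requested and nothing
	// loaded under LoadType_LoadPartition.
	_, err := resolveSearchPartitions(nil, nil, true)
	fmt.Println(err)
}
```

The earlier ExistPartition check, which could silently leave searchPartIDs empty when a requested partition was missing, is gone; per the new comment, partition membership is now the delegator's responsibility.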
6 changes: 3 additions & 3 deletions internal/querynodev2/tasks/search_task.go
@@ -160,7 +160,7 @@ func (t *SearchTask) Execute() error {
t.segmentManager,
searchReq,
req.GetReq().GetCollectionID(),
nil,
req.GetReq().GetPartitionIDs(),
req.GetSegmentIDs(),
)
} else if req.GetScope() == querypb.DataScope_Streaming {
@@ -169,7 +169,7 @@
t.segmentManager,
searchReq,
req.GetReq().GetCollectionID(),
nil,
req.GetReq().GetPartitionIDs(),
req.GetSegmentIDs(),
)
}
@@ -475,7 +475,7 @@ func (t *StreamingSearchTask) Execute() error {
t.segmentManager,
searchReq,
req.GetReq().GetCollectionID(),
nil,
req.GetReq().GetPartitionIDs(),
req.GetSegmentIDs(),
)
defer segments.DeleteSearchResults(results)