From 01977ffb27ec9915c1c14e7c408fcc3705f915b8 Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Fri, 18 Oct 2024 18:35:52 +0800 Subject: [PATCH] enhance: Support db for bulkinsert Signed-off-by: bigsheeper --- internal/datacoord/import_job.go | 7 ++++ internal/datacoord/services.go | 22 +++++++--- internal/datacoord/services_test.go | 40 +++++++++++++++++-- .../proxy/httpserver/handler_v2.go | 4 +- .../proxy/httpserver/request_v2.go | 9 +++-- internal/proto/internal.proto | 1 + internal/proxy/impl.go | 30 +++++++++++++- internal/proxy/impl_test.go | 40 ++++++++++++++++++- 8 files changed, 137 insertions(+), 16 deletions(-) diff --git a/internal/datacoord/import_job.go b/internal/datacoord/import_job.go index 39645ec5154be..329a26c7f5a53 100644 --- a/internal/datacoord/import_job.go +++ b/internal/datacoord/import_job.go @@ -39,6 +39,12 @@ func WithCollectionID(collectionID int64) ImportJobFilter { } } +func WithDbID(DbID int64) ImportJobFilter { + return func(job ImportJob) bool { + return job.GetDbID() == DbID + } +} + type UpdateJobAction func(job ImportJob) func UpdateJobState(state internalpb.ImportJobState) UpdateJobAction { @@ -78,6 +84,7 @@ func UpdateJobCompleteTime(completeTime string) UpdateJobAction { type ImportJob interface { GetJobID() int64 + GetDbID() int64 GetCollectionID() int64 GetCollectionName() string GetPartitionIDs() []int64 diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 6d8a092c0c761..7caeb619dbbbf 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -1674,7 +1674,9 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter Status: merr.Success(), } - log := log.With(zap.Int64("collection", in.GetCollectionID()), + log := log.With( + zap.Int64("dbID", in.GetDbID()), + zap.Int64("collection", in.GetCollectionID()), zap.Int64s("partitions", in.GetPartitionIDs()), zap.Strings("channels", in.GetChannelNames())) log.Info("receive import request", zap.Any("files", in.GetFiles())) @@ -1728,6 +1730,7 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter job := &importJob{ ImportJob: &datapb.ImportJob{ JobID: idStart, + DbID: in.GetDbID(), CollectionID: in.GetCollectionID(), CollectionName: in.GetCollectionName(), PartitionIDs: in.GetPartitionIDs(), @@ -1754,7 +1757,7 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter } func (s *Server) GetImportProgress(ctx context.Context, in *internalpb.GetImportProgressRequest) (*internalpb.GetImportProgressResponse, error) { - log := log.With(zap.String("jobID", in.GetJobID())) + log := log.With(zap.String("jobID", in.GetJobID()), zap.Int64("dbID", in.GetDbID())) if err := merr.CheckHealthy(s.GetStateCode()); err != nil { return &internalpb.GetImportProgressResponse{ Status: merr.Status(err), @@ -1774,6 +1777,10 @@ func (s *Server) GetImportProgress(ctx context.Context, in *internalpb.GetImport resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("import job does not exist, jobID=%d", jobID))) return resp, nil } + if job.GetDbID() != 0 && job.GetDbID() != in.GetDbID() { + resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("import job does not exist, jobID=%d, dbID=%d", jobID, in.GetDbID()))) + return resp, nil + } progress, state, importedRows, totalRows, reason := GetJobProgress(jobID, s.importMeta, s.meta, s.jobManager) resp.State = state resp.Reason = reason @@ -1804,11 +1811,14 @@ func (s *Server) ListImports(ctx context.Context, req *internalpb.ListImportsReq 
} var jobs []ImportJob + filters := make([]ImportJobFilter, 0) + if req.GetDbID() != 0 { + filters = append(filters, WithDbID(req.GetDbID())) + } if req.GetCollectionID() != 0 { - jobs = s.importMeta.GetJobBy(WithCollectionID(req.GetCollectionID())) - } else { - jobs = s.importMeta.GetJobBy() + filters = append(filters, WithCollectionID(req.GetCollectionID())) } + jobs = s.importMeta.GetJobBy(filters...) for _, job := range jobs { progress, state, _, _, reason := GetJobProgress(job.GetJobID(), s.importMeta, s.meta, s.jobManager) @@ -1818,5 +1828,7 @@ func (s *Server) ListImports(ctx context.Context, req *internalpb.ListImportsReq resp.Progresses = append(resp.Progresses, progress) resp.CollectionNames = append(resp.CollectionNames, job.GetCollectionName()) } + log.Info("ListImports done", zap.Int64("collectionID", req.GetCollectionID()), + zap.Int64("dbID", req.GetDbID()), zap.Any("resp", resp)) return resp, nil } diff --git a/internal/datacoord/services_test.go b/internal/datacoord/services_test.go index 9735c8f6ca184..e8c3d3d126d8f 100644 --- a/internal/datacoord/services_test.go +++ b/internal/datacoord/services_test.go @@ -1399,9 +1399,10 @@ func TestImportV2(t *testing.T) { assert.NoError(t, err) assert.True(t, errors.Is(merr.Error(resp.GetStatus()), merr.ErrImportFailed)) - // normal case + // db does not exist var job ImportJob = &importJob{ ImportJob: &datapb.ImportJob{ + DbID: 1, JobID: 0, Schema: &schemapb.CollectionSchema{}, State: internalpb.ImportJobState_Failed, @@ -1410,12 +1411,31 @@ func TestImportV2(t *testing.T) { err = s.importMeta.AddJob(job) assert.NoError(t, err) resp, err = s.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{ + DbID: 2, + JobID: "0", + }) + assert.NoError(t, err) + assert.True(t, errors.Is(merr.Error(resp.GetStatus()), merr.ErrImportFailed)) + + // normal case + job = &importJob{ + ImportJob: &datapb.ImportJob{ + DbID: 1, + JobID: 0, + Schema: &schemapb.CollectionSchema{}, + State: internalpb.ImportJobState_Pending, + }, + } + err = s.importMeta.AddJob(job) + assert.NoError(t, err) + resp, err = s.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{ + DbID: 1, JobID: "0", }) assert.NoError(t, err) assert.Equal(t, int32(0), resp.GetStatus().GetCode()) - assert.Equal(t, int64(0), resp.GetProgress()) - assert.Equal(t, internalpb.ImportJobState_Failed, resp.GetState()) + assert.Equal(t, int64(10), resp.GetProgress()) + assert.Equal(t, internalpb.ImportJobState_Pending, resp.GetState()) }) t.Run("ListImports", func(t *testing.T) { @@ -1438,6 +1458,7 @@ func TestImportV2(t *testing.T) { assert.NoError(t, err) var job ImportJob = &importJob{ ImportJob: &datapb.ImportJob{ + DbID: 2, JobID: 0, CollectionID: 1, Schema: &schemapb.CollectionSchema{}, @@ -1454,7 +1475,20 @@ func TestImportV2(t *testing.T) { } err = s.importMeta.AddTask(task) assert.NoError(t, err) + // db id not match + resp, err = s.ListImports(ctx, &internalpb.ListImportsRequestInternal{ + DbID: 3, + CollectionID: 1, + }) + assert.NoError(t, err) + assert.Equal(t, int32(0), resp.GetStatus().GetCode()) + assert.Equal(t, 0, len(resp.GetJobIDs())) + assert.Equal(t, 0, len(resp.GetStates())) + assert.Equal(t, 0, len(resp.GetReasons())) + assert.Equal(t, 0, len(resp.GetProgresses())) + // db id match resp, err = s.ListImports(ctx, &internalpb.ListImportsRequestInternal{ + DbID: 2, CollectionID: 1, }) assert.NoError(t, err) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index a2a1d21a7258d..800df8198af54 100644 
--- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -139,8 +139,8 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { router.POST(ImportJobCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &OptionalCollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listImportJob))))) router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.createImportJob))))) - router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) - router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) + router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &GetImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) + router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &GetImportReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.getImportJobProcess))))) } type ( diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 31bf6f0d585b6..0ea2cc42a36c0 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -94,11 +94,14 @@ func (req *ImportReq) GetOptions() map[string]string { return req.Options } -type JobIDReq struct { - JobID string `json:"jobId" binding:"required"` +type GetImportReq struct { + DbName string `json:"dbName"` + JobID string `json:"jobId" binding:"required"` } -func (req *JobIDReq) GetJobID() string { return req.JobID } +func (req *GetImportReq) GetJobID() string { return req.JobID } + +func (req *GetImportReq) GetDbName() string { return req.DbName } type QueryReqV2 struct { DbName string `json:"dbName"` diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index 154191d2db51c..cdf2d9e028697 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -343,6 +343,7 @@ message ImportResponse { message GetImportProgressRequest { string db_name = 1; string jobID = 2; + int64 dbID = 3; } message ImportTaskProgress { diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index dfff13db75d36..21a1b926f6d89 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -6189,6 +6189,7 @@ func (node *Proxy) ImportV2(ctx context.Context, req *internalpb.ImportRequest) return &internalpb.ImportResponse{Status: merr.Status(err)}, nil } log := log.Ctx(ctx).With( + zap.String("dbName", req.GetDbName()), zap.String("collectionName", req.GetCollectionName()), zap.String("partition name", req.GetPartitionName()), zap.Any("files", req.GetFiles()), @@ -6214,6 +6215,11 @@ func (node *Proxy) ImportV2(ctx context.Context, req *internalpb.ImportRequest) } }() + dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, req.GetDbName()) + if err != nil { + resp.Status = merr.Status(err) + return resp, nil + } collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) if err != nil { resp.Status = merr.Status(err) @@ -6324,6 +6330,7 @@ func (node *Proxy) ImportV2(ctx context.Context, req *internalpb.ImportRequest) } } importRequest := &internalpb.ImportRequestInternal{ 
+ DbID: dbInfo.dbID, CollectionID: collectionID, CollectionName: req.GetCollectionName(), PartitionIDs: partitionIDs, @@ -6348,14 +6355,28 @@ func (node *Proxy) GetImportProgress(ctx context.Context, req *internalpb.GetImp }, nil } log := log.Ctx(ctx).With( + zap.String("dbName", req.GetDbName()), zap.String("jobID", req.GetJobID()), ) + + resp := &internalpb.GetImportProgressResponse{ + Status: merr.Success(), + } + method := "GetImportProgress" tr := timerecord.NewTimeRecorder(method) log.Info(rpcReceived(method)) + // Fill db id for datacoord. + dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, req.GetDbName()) + if err != nil { + resp.Status = merr.Status(err) + return resp, nil + } + req.DbID = dbInfo.dbID + nodeID := fmt.Sprint(paramtable.GetNodeID()) - resp, err := node.dataCoord.GetImportProgress(ctx, req) + resp, err = node.dataCoord.GetImportProgress(ctx, req) if resp.GetStatus().GetCode() != 0 || err != nil { log.Warn("get import progress failed", zap.String("reason", resp.GetStatus().GetReason()), zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(nodeID, method, metrics.FailLabel, req.GetDbName(), "").Inc() @@ -6392,6 +6413,11 @@ func (node *Proxy) ListImports(ctx context.Context, req *internalpb.ListImportsR err error collectionID UniqueID ) + dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, req.GetDbName()) + if err != nil { + resp.Status = merr.Status(err) + return resp, nil + } if req.GetCollectionName() != "" { collectionID, err = globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) if err != nil { @@ -6400,7 +6426,9 @@ func (node *Proxy) ListImports(ctx context.Context, req *internalpb.ListImportsR return resp, nil } } + resp, err = node.dataCoord.ListImports(ctx, &internalpb.ListImportsRequestInternal{ + DbID: dbInfo.dbID, CollectionID: collectionID, }) if resp.GetStatus().GetCode() != 0 || err != nil { diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index 70171fe3b84cd..b2ab3cde1a2d4 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -1616,8 +1616,17 @@ func TestProxy_ImportV2(t *testing.T) { assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) node.UpdateStateCode(commonpb.StateCode_Healthy) - // no such collection + // no such database mc := NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, mockErr) + globalMetaCache = mc + rsp, err = node.ImportV2(ctx, &internalpb.ImportRequest{CollectionName: "aaa"}) + assert.NoError(t, err) + assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) + + // no such collection + mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, mockErr) globalMetaCache = mc rsp, err = node.ImportV2(ctx, &internalpb.ImportRequest{CollectionName: "aaa"}) @@ -1626,6 +1635,7 @@ func TestProxy_ImportV2(t *testing.T) { // get schema failed mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(nil, mockErr) globalMetaCache = mc @@ -1635,6 +1645,7 @@ func TestProxy_ImportV2(t *testing.T) { // get channel failed mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) 
mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{ CollectionSchema: &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{ @@ -1659,6 +1670,7 @@ func TestProxy_ImportV2(t *testing.T) { // get partitions failed mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{ CollectionSchema: &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{ @@ -1673,6 +1685,7 @@ func TestProxy_ImportV2(t *testing.T) { // get partitionID failed mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{ CollectionSchema: &schemapb.CollectionSchema{}, @@ -1685,6 +1698,7 @@ func TestProxy_ImportV2(t *testing.T) { // no file mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) mc.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{ CollectionSchema: &schemapb.CollectionSchema{}, @@ -1731,7 +1745,18 @@ func TestProxy_ImportV2(t *testing.T) { assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) node.UpdateStateCode(commonpb.StateCode_Healthy) + // no such database + mc := NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, mockErr) + globalMetaCache = mc + rsp, err = node.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{}) + assert.NoError(t, err) + assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) + // normal case + mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) + globalMetaCache = mc dataCoord := mocks.NewMockDataCoordClient(t) dataCoord.EXPECT().GetImportProgress(mock.Anything, mock.Anything).Return(nil, nil) node.dataCoord = dataCoord @@ -1749,8 +1774,19 @@ func TestProxy_ImportV2(t *testing.T) { assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) node.UpdateStateCode(commonpb.StateCode_Healthy) - // normal case + // no such database mc := NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, mockErr) + globalMetaCache = mc + rsp, err = node.ListImports(ctx, &internalpb.ListImportsRequest{ + CollectionName: "col", + }) + assert.NoError(t, err) + assert.NotEqual(t, int32(0), rsp.GetStatus().GetCode()) + + // normal case + mc = NewMockCache(t) + mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 1}, nil) mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil) globalMetaCache = mc dataCoord := mocks.NewMockDataCoordClient(t)
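The core of the datacoord-side change above is the composable `ImportJobFilter`: `ListImports` now appends a `WithDbID` filter (and the existing `WithCollectionID` filter) only when the request actually sets those fields, then passes the whole set to `importMeta.GetJobBy`. The snippet below is a standalone, minimal sketch of that filter-composition pattern using simplified stand-in types; `importJob`, `importJobFilter`, and `getJobBy` here are illustrative placeholders, not the real `internal/datacoord` definitions, and the zero-value-means-unset convention is taken directly from the patch.

```go
package main

import "fmt"

// Simplified stand-ins for the ImportJob / ImportJobFilter types touched by
// the patch; the real types carry schema, partitions, channels, etc.
type importJob struct {
	jobID        int64
	dbID         int64
	collectionID int64
}

type importJobFilter func(job importJob) bool

// withDbID mirrors the new WithDbID filter added in import_job.go.
func withDbID(dbID int64) importJobFilter {
	return func(job importJob) bool { return job.dbID == dbID }
}

// withCollectionID mirrors the pre-existing WithCollectionID filter.
func withCollectionID(collectionID int64) importJobFilter {
	return func(job importJob) bool { return job.collectionID == collectionID }
}

// getJobBy mirrors importMeta.GetJobBy: a job is returned only if it passes
// every supplied filter; with no filters, every job is returned.
func getJobBy(jobs []importJob, filters ...importJobFilter) []importJob {
	var out []importJob
	for _, job := range jobs {
		ok := true
		for _, f := range filters {
			if !f(job) {
				ok = false
				break
			}
		}
		if ok {
			out = append(out, job)
		}
	}
	return out
}

func main() {
	jobs := []importJob{
		{jobID: 1, dbID: 1, collectionID: 100},
		{jobID: 2, dbID: 2, collectionID: 100},
	}

	// Pretend these came from a ListImportsRequestInternal; as in the patch,
	// a zero dbID/collectionID means "no constraint", so no filter is added.
	reqDbID, reqCollectionID := int64(2), int64(100)

	filters := make([]importJobFilter, 0)
	if reqDbID != 0 {
		filters = append(filters, withDbID(reqDbID))
	}
	if reqCollectionID != 0 {
		filters = append(filters, withCollectionID(reqCollectionID))
	}

	// Only job 2 satisfies both filters.
	fmt.Println(getJobBy(jobs, filters...))
}
```

The same convention explains the proxy-side plumbing: the proxy resolves `dbName` to a `dbID` via `globalMetaCache.GetDatabaseInfo` and forwards it, while datacoord treats a zero `DbID` (older jobs or callers that do not send one) as unfiltered, which is why `GetImportProgress` only rejects a mismatch when the stored job has a non-zero `DbID`.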