diff --git a/DEPS.bzl b/DEPS.bzl
index 3e3f661f5528c..71d4df7d3a774 100644
--- a/DEPS.bzl
+++ b/DEPS.bzl
@@ -6807,13 +6807,13 @@ def go_deps():
         name = "com_github_tikv_client_go_v2",
         build_file_proto_mode = "disable_global",
         importpath = "github.com/tikv/client-go/v2",
-        sha256 = "3a9d97649d1c917faebb2f7756e750213ab5fd34c070f17a38c9227b201862c9",
-        strip_prefix = "github.com/tikv/client-go/v2@v2.0.8-0.20240409022718-714958ccd4d5",
+        sha256 = "9dc8899d26420c39a52014f920ba361e445f2600f613669f2b20a62dd5a56791",
+        strip_prefix = "github.com/tikv/client-go/v2@v2.0.8-0.20240424052342-0229f4077f0c",
         urls = [
-            "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20240409022718-714958ccd4d5.zip",
-            "http://ats.apps.svc/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20240409022718-714958ccd4d5.zip",
-            "https://cache.hawkingrei.com/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20240409022718-714958ccd4d5.zip",
-            "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20240409022718-714958ccd4d5.zip",
+            "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20240424052342-0229f4077f0c.zip",
+            "http://ats.apps.svc/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20240424052342-0229f4077f0c.zip",
+            "https://cache.hawkingrei.com/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20240424052342-0229f4077f0c.zip",
+            "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20240424052342-0229f4077f0c.zip",
         ],
     )
     go_repository(
diff --git a/br/pkg/restore/import.go b/br/pkg/restore/import.go
index be549ac13c6a0..6f21ca7229ad3 100644
--- a/br/pkg/restore/import.go
+++ b/br/pkg/restore/import.go
@@ -1340,7 +1340,8 @@ func (importer *FileImporter) ingestSSTs(
 ) (*import_sstpb.IngestResponse, error) {
 	leader := regionInfo.Leader
 	if leader == nil {
-		leader = regionInfo.Region.GetPeers()[0]
+		return nil, errors.Annotatef(berrors.ErrPDLeaderNotFound,
+			"region id %d has no leader", regionInfo.Region.Id)
 	}
 	reqCtx := &kvrpcpb.Context{
 		RegionId: regionInfo.Region.GetId(),
diff --git a/br/pkg/restore/import_retry_test.go b/br/pkg/restore/import_retry_test.go
index c10a6c56b14c8..8e2a386b0e5f5 100644
--- a/br/pkg/restore/import_retry_test.go
+++ b/br/pkg/restore/import_retry_test.go
@@ -70,7 +70,8 @@ func initTestClient(isRawKv bool) *TestClient {
 		}
 		regions[i] = &split.RegionInfo{
 			Leader: &metapb.Peer{
-				Id: i,
+				Id:      i,
+				StoreId: 1,
 			},
 			Region: &metapb.Region{
 				Id: i,
@@ -281,7 +282,7 @@ func TestEpochNotMatch(t *testing.T) {
 				{Id: 43},
 			},
 		},
-		Leader: &metapb.Peer{Id: 43},
+		Leader: &metapb.Peer{Id: 43, StoreId: 1},
 	}
 	newRegion := pdtypes.NewRegionInfo(info.Region, info.Leader)
 	mergeRegion := func() {
@@ -340,7 +341,8 @@ func TestRegionSplit(t *testing.T) {
 				EndKey: codec.EncodeBytes(nil, []byte("aayy")),
 			},
 			Leader: &metapb.Peer{
-				Id: 43,
+				Id:      43,
+				StoreId: 1,
 			},
 		},
 		{
@@ -350,7 +352,8 @@
 				EndKey: target.Region.EndKey,
 			},
 			Leader: &metapb.Peer{
-				Id: 45,
+				Id:      45,
+				StoreId: 1,
 			},
 		},
 	}
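The import.go hunk above turns a silent fallback into a hard failure: ingest requests must target the region's leader, and grabbing GetPeers()[0] when PD reports no leader can hand the SST to a follower or a downed store. A minimal runnable sketch of that invariant, using the kvproto metapb types from the diff but a plain stand-in error rather than the exact berrors code:

	package main

	import (
		"errors"
		"fmt"

		"github.com/pingcap/kvproto/pkg/metapb"
	)

	// regionInfo mirrors the shape of br's split.RegionInfo for illustration only.
	type regionInfo struct {
		Region *metapb.Region
		Leader *metapb.Peer
	}

	var errNoLeader = errors.New("leader not found") // stand-in for berrors.ErrPDLeaderNotFound

	// leaderOf refuses to guess when PD has not reported a leader, matching
	// the new behavior of FileImporter.ingestSSTs.
	func leaderOf(r *regionInfo) (*metapb.Peer, error) {
		if r.Leader == nil {
			// The old code fell back to r.Region.GetPeers()[0] here, which may
			// be a follower and make the subsequent Ingest RPC fail.
			return nil, fmt.Errorf("region %d: %w", r.Region.GetId(), errNoLeader)
		}
		return r.Leader, nil
	}

	func main() {
		_, err := leaderOf(&regionInfo{Region: &metapb.Region{Id: 1}})
		fmt.Println(err) // region 1: leader not found
	}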
diff --git a/br/pkg/restore/split/mock_pd_client.go b/br/pkg/restore/split/mock_pd_client.go
index cc01d68ecfc45..4bd709260e90a 100644
--- a/br/pkg/restore/split/mock_pd_client.go
+++ b/br/pkg/restore/split/mock_pd_client.go
@@ -74,8 +74,13 @@ func (c *MockPDClientForSplit) setRegions(boundaries [][]byte) []*metapb.Region
 			StartKey: boundaries[i-1],
 			EndKey:   boundaries[i],
 		}
+		p := &metapb.Peer{
+			Id:      c.lastRegionID,
+			StoreId: 1,
+		}
 		c.Regions.SetRegion(&pdtypes.Region{
-			Meta: r,
+			Meta:   r,
+			Leader: p,
 		})
 		ret = append(ret, r)
 	}
diff --git a/br/pkg/restore/split/split.go b/br/pkg/restore/split/split.go
index 97197df839ccb..c69e5959f9812 100644
--- a/br/pkg/restore/split/split.go
+++ b/br/pkg/restore/split/split.go
@@ -59,7 +59,23 @@ func checkRegionConsistency(startKey, endKey []byte, regions []*RegionInfo) erro
 	}
 
 	cur := regions[0]
+	if cur.Leader == nil {
+		return errors.Annotatef(berrors.ErrPDBatchScanRegion,
+			"region %d's leader is nil", cur.Region.Id)
+	}
+	if cur.Leader.StoreId == 0 {
+		return errors.Annotatef(berrors.ErrPDBatchScanRegion,
+			"region %d's leader's store id is 0", cur.Region.Id)
+	}
 	for _, r := range regions[1:] {
+		if r.Leader == nil {
+			return errors.Annotatef(berrors.ErrPDBatchScanRegion,
+				"region %d's leader is nil", r.Region.Id)
+		}
+		if r.Leader.StoreId == 0 {
+			return errors.Annotatef(berrors.ErrPDBatchScanRegion,
+				"region %d's leader's store id is 0", r.Region.Id)
+		}
 		if !bytes.Equal(cur.Region.EndKey, r.Region.StartKey) {
 			return errors.Annotatef(berrors.ErrPDBatchScanRegion,
 				"region %d's endKey not equal to next region %d's startKey, endKey: %s, startKey: %s, region epoch: %s %s",
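checkRegionConsistency now treats a scan result as broken not only when the key ranges have gaps but also when any region arrives without a leader or with a leader whose store id is 0, so PaginateScanRegion retries instead of failing later during ingest. A condensed, dependency-free sketch of the rule it enforces (plain errors stand in for the berrors.ErrPDBatchScanRegion annotations):

	package main

	import (
		"bytes"
		"fmt"
	)

	type region struct {
		id               int64
		startKey, endKey []byte
		leaderStoreID    int64 // 0 means PD reported no usable leader
	}

	// checkConsistency mirrors the leader and contiguity checks added above.
	func checkConsistency(regions []region) error {
		for i, r := range regions {
			if r.leaderStoreID == 0 {
				return fmt.Errorf("region %d has no usable leader", r.id)
			}
			if i > 0 && !bytes.Equal(regions[i-1].endKey, r.startKey) {
				return fmt.Errorf("region %d's endKey does not meet region %d's startKey",
					regions[i-1].id, r.id)
			}
		}
		return nil
	}

	func main() {
		err := checkConsistency([]region{{id: 6, startKey: []byte("c"), endKey: []byte("d")}})
		fmt.Println(err) // region 6 has no usable leader
	}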
diff --git a/br/pkg/restore/split/split_test.go b/br/pkg/restore/split/split_test.go
index 077cdcdd1cb54..9ca523fe214f4 100644
--- a/br/pkg/restore/split/split_test.go
+++ b/br/pkg/restore/split/split_test.go
@@ -504,6 +504,10 @@ func TestPaginateScanRegion(t *testing.T) {
 			StartKey: []byte{1},
 			EndKey:   []byte{2},
 		},
+		Leader: &metapb.Peer{
+			Id:      1,
+			StoreId: 1,
+		},
 	})
 	mockPDClient.Regions.SetRegion(&pdtypes.Region{
 		Meta: &metapb.Region{
@@ -511,6 +515,10 @@
 			StartKey: []byte{4},
 			EndKey:   []byte{5},
 		},
+		Leader: &metapb.Peer{
+			Id:      4,
+			StoreId: 1,
+		},
 	})
 
 	_, err = PaginateScanRegion(ctx, mockClient, []byte{1}, []byte{5}, 3)
@@ -525,6 +533,10 @@
 			StartKey: []byte{2},
 			EndKey:   []byte{3},
 		},
+		Leader: &metapb.Peer{
+			Id:      2,
+			StoreId: 1,
+		},
 	},
 	{
 		Meta: &metapb.Region{
@@ -532,6 +544,10 @@
 			StartKey: []byte{3},
 			EndKey:   []byte{4},
 		},
+		Leader: &metapb.Peer{
+			Id:      3,
+			StoreId: 1,
+		},
 	},
 }
 mockPDClient.scanRegions.beforeHook = func() {
@@ -590,6 +606,10 @@ func TestRegionConsistency(t *testing.T) {
 			"region 6's endKey not equal to next region 8's startKey(.*?)",
 			[]*RegionInfo{
 				{
+					Leader: &metapb.Peer{
+						Id:      6,
+						StoreId: 1,
+					},
 					Region: &metapb.Region{
 						Id:       6,
 						StartKey: codec.EncodeBytes([]byte{}, []byte("b")),
@@ -598,6 +618,10 @@
 					},
 				},
 				{
+					Leader: &metapb.Peer{
+						Id:      8,
+						StoreId: 1,
+					},
 					Region: &metapb.Region{
 						Id:       8,
 						StartKey: codec.EncodeBytes([]byte{}, []byte("e")),
@@ -606,6 +630,58 @@
 					},
 				},
 			},
+		{
+			codec.EncodeBytes([]byte{}, []byte("c")),
+			codec.EncodeBytes([]byte{}, []byte("e")),
+			"region 6's leader is nil(.*?)",
+			[]*RegionInfo{
+				{
+					Region: &metapb.Region{
+						Id:          6,
+						StartKey:    codec.EncodeBytes([]byte{}, []byte("c")),
+						EndKey:      codec.EncodeBytes([]byte{}, []byte("d")),
+						RegionEpoch: nil,
+					},
+				},
+				{
+					Region: &metapb.Region{
+						Id:       8,
+						StartKey: codec.EncodeBytes([]byte{}, []byte("d")),
+						EndKey:   codec.EncodeBytes([]byte{}, []byte("e")),
+					},
+				},
+			},
+		},
+		{
+			codec.EncodeBytes([]byte{}, []byte("c")),
+			codec.EncodeBytes([]byte{}, []byte("e")),
+			"region 6's leader's store id is 0(.*?)",
+			[]*RegionInfo{
+				{
+					Leader: &metapb.Peer{
+						Id:      6,
+						StoreId: 0,
+					},
+					Region: &metapb.Region{
+						Id:          6,
+						StartKey:    codec.EncodeBytes([]byte{}, []byte("c")),
+						EndKey:      codec.EncodeBytes([]byte{}, []byte("d")),
+						RegionEpoch: nil,
+					},
+				},
+				{
+					Leader: &metapb.Peer{
+						Id:      6,
+						StoreId: 0,
+					},
+					Region: &metapb.Region{
+						Id:       8,
+						StartKey: codec.EncodeBytes([]byte{}, []byte("d")),
+						EndKey:   codec.EncodeBytes([]byte{}, []byte("e")),
+					},
+				},
+			},
+		},
 	}
 	for _, ca := range cases {
 		err := checkRegionConsistency(ca.startKey, ca.endKey, ca.regions)
diff --git a/go.mod b/go.mod
index 6cedcfa435e4d..ab2d959d1cd5a 100644
--- a/go.mod
+++ b/go.mod
@@ -106,7 +106,7 @@ require (
 	github.com/tdakkota/asciicheck v0.2.0
 	github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2
 	github.com/tidwall/btree v1.7.0
-	github.com/tikv/client-go/v2 v2.0.8-0.20240409022718-714958ccd4d5
+	github.com/tikv/client-go/v2 v2.0.8-0.20240424052342-0229f4077f0c
 	github.com/tikv/pd/client v0.0.0-20240322051414-fb9e2d561b6e
 	github.com/timakin/bodyclose v0.0.0-20240125160201-f835fa56326a
 	github.com/twmb/murmur3 v1.1.6
diff --git a/go.sum b/go.sum
index 0502d3a3e9a78..d9ff6090529a5 100644
--- a/go.sum
+++ b/go.sum
@@ -785,8 +785,8 @@ github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a h1:J/YdBZ46WKpXsxsW
 github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM=
 github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI=
 github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY=
-github.com/tikv/client-go/v2 v2.0.8-0.20240409022718-714958ccd4d5 h1:NIYXG5l8JrDyc7k0zO17ppKJkRlUfKMWFOnjIQT5Tc4=
-github.com/tikv/client-go/v2 v2.0.8-0.20240409022718-714958ccd4d5/go.mod h1:+vXk4Aex17GnI8gfSMPxrL0SQLbBYgP3Db4FvHiImwM=
+github.com/tikv/client-go/v2 v2.0.8-0.20240424052342-0229f4077f0c h1:M97Y/RO0vGpX0FplwGTk02idZDmSPEJlO6fTCPaxkCI=
+github.com/tikv/client-go/v2 v2.0.8-0.20240424052342-0229f4077f0c/go.mod h1:+vXk4Aex17GnI8gfSMPxrL0SQLbBYgP3Db4FvHiImwM=
 github.com/tikv/pd/client v0.0.0-20240322051414-fb9e2d561b6e h1:u2OoEvmh3qyjIiAKXUPRiFCOSwznByMINDx2fsorjAo=
 github.com/tikv/pd/client v0.0.0-20240322051414-fb9e2d561b6e/go.mod h1:Z/QAgOt29zvwBTd0H6pdx45VO6KRNc/O/DzGkVmSyZg=
 github.com/timakin/bodyclose v0.0.0-20240125160201-f835fa56326a h1:A6uKudFIfAEpoPdaal3aSqGxBzLyU8TqyXImLwo6dIo=
diff --git a/lightning/pkg/importer/BUILD.bazel b/lightning/pkg/importer/BUILD.bazel
index 0075a4db17b97..83797e95530cb 100644
--- a/lightning/pkg/importer/BUILD.bazel
+++ b/lightning/pkg/importer/BUILD.bazel
@@ -54,7 +54,6 @@ go_library(
         "//pkg/meta/autoid",
         "//pkg/parser",
         "//pkg/parser/ast",
-        "//pkg/parser/format",
         "//pkg/parser/model",
         "//pkg/parser/mysql",
         "//pkg/planner/core",
@@ -116,7 +115,6 @@ go_test(
         "meta_manager_test.go",
        "precheck_impl_test.go",
         "precheck_test.go",
-        "restore_schema_test.go",
         "table_import_test.go",
         "tidb_test.go",
     ],
"github.com/pingcap/tidb/pkg/sessionctx/variable" @@ -583,324 +582,19 @@ outside: return errors.Trace(err) } -type schemaStmtType int - -// String implements fmt.Stringer interface. -func (stmtType schemaStmtType) String() string { - switch stmtType { - case schemaCreateDatabase: - return "restore database schema" - case schemaCreateTable: - return "restore table schema" - case schemaCreateView: - return "restore view schema" - } - return "unknown statement of schema" -} - -const ( - schemaCreateDatabase schemaStmtType = iota - schemaCreateTable - schemaCreateView -) - -type schemaJob struct { - dbName string - tblName string // empty for create db jobs - stmtType schemaStmtType - stmts []string -} - -type restoreSchemaWorker struct { - ctx context.Context - quit context.CancelFunc - logger log.Logger - jobCh chan *schemaJob - errCh chan error - wg sync.WaitGroup - db *sql.DB - parser *parser.Parser - store storage.ExternalStorage -} - -func (worker *restoreSchemaWorker) addJob(sqlStr string, job *schemaJob) error { - stmts, err := createIfNotExistsStmt(worker.parser, sqlStr, job.dbName, job.tblName) - if err != nil { - return errors.Trace(err) - } - job.stmts = stmts - return worker.appendJob(job) -} - -func (worker *restoreSchemaWorker) makeJobs( - dbMetas []*mydump.MDDatabaseMeta, - getDBs func(context.Context) ([]*model.DBInfo, error), - getTables func(context.Context, string) ([]*model.TableInfo, error), -) error { - defer func() { - close(worker.jobCh) - worker.quit() - }() - - if len(dbMetas) == 0 { - return nil - } - - // 1. restore databases, execute statements concurrency - - dbs, err := getDBs(worker.ctx) - if err != nil { - worker.logger.Warn("get databases from downstream failed", zap.Error(err)) - } - dbSet := make(set.StringSet, len(dbs)) - for _, db := range dbs { - dbSet.Insert(db.Name.L) - } - - for _, dbMeta := range dbMetas { - // if downstream already has this database, we can skip ddl job - if dbSet.Exist(strings.ToLower(dbMeta.Name)) { - worker.logger.Info( - "database already exists in downstream, skip processing the source file", - zap.String("db", dbMeta.Name), - ) - continue - } - - sql := dbMeta.GetSchema(worker.ctx, worker.store) - err = worker.addJob(sql, &schemaJob{ - dbName: dbMeta.Name, - tblName: "", - stmtType: schemaCreateDatabase, - }) - if err != nil { - return err - } - } - err = worker.wait() - if err != nil { - return err - } - - // 2. restore tables, execute statements concurrency - - for _, dbMeta := range dbMetas { - // we can ignore error here, and let check failed later if schema not match - tables, err := getTables(worker.ctx, dbMeta.Name) - if err != nil { - worker.logger.Warn("get tables from downstream failed", zap.Error(err)) - } - tableSet := make(set.StringSet, len(tables)) - for _, t := range tables { - tableSet.Insert(t.Name.L) - } - for _, tblMeta := range dbMeta.Tables { - if tableSet.Exist(strings.ToLower(tblMeta.Name)) { - // we already has this table in TiDB. - // we should skip ddl job and let SchemaValid check. 
-				worker.logger.Info(
-					"table already exists in downstream, skip processing the source file",
-					zap.String("db", dbMeta.Name),
-					zap.String("table", tblMeta.Name),
-				)
-				continue
-			} else if tblMeta.SchemaFile.FileMeta.Path == "" {
-				return common.ErrSchemaNotExists.GenWithStackByArgs(dbMeta.Name, tblMeta.Name)
-			}
-			sql, err := tblMeta.GetSchema(worker.ctx, worker.store)
-			if err != nil {
-				return err
-			}
-			if sql != "" {
-				err = worker.addJob(sql, &schemaJob{
-					dbName:   dbMeta.Name,
-					tblName:  tblMeta.Name,
-					stmtType: schemaCreateTable,
-				})
-				if err != nil {
-					return err
-				}
-			}
-		}
-	}
-	err = worker.wait()
-	if err != nil {
-		return err
-	}
-	// 3. restore views. Since views can cross database we must restore views after all table schemas are restored.
-	for _, dbMeta := range dbMetas {
-		for _, viewMeta := range dbMeta.Views {
-			sql, err := viewMeta.GetSchema(worker.ctx, worker.store)
-			if sql != "" {
-				err = worker.addJob(sql, &schemaJob{
-					dbName:   dbMeta.Name,
-					tblName:  viewMeta.Name,
-					stmtType: schemaCreateView,
-				})
-				if err != nil {
-					return err
-				}
-				// we don't support restore views concurrency, cauz it maybe will raise a error
-				err = worker.wait()
-				if err != nil {
-					return err
-				}
-			}
-			if err != nil {
-				return err
-			}
-		}
-	}
-	return nil
-}
-
-func (worker *restoreSchemaWorker) doJob() {
-	var session *sql.Conn
-	defer func() {
-		if session != nil {
-			_ = session.Close()
-		}
-	}()
-loop:
-	for {
-		select {
-		case <-worker.ctx.Done():
-			// don't `return` or throw `worker.ctx.Err()`here,
-			// if we `return`, we can't mark cancelled jobs as done,
-			// if we `throw(worker.ctx.Err())`, it will be blocked to death
-			break loop
-		case job := <-worker.jobCh:
-			if job == nil {
-				// successful exit
-				return
-			}
-			var err error
-			if session == nil {
-				session, err = func() (*sql.Conn, error) {
-					return worker.db.Conn(worker.ctx)
-				}()
-				if err != nil {
-					worker.wg.Done()
-					worker.throw(err)
-					// don't return
-					break loop
-				}
-			}
-			logger := worker.logger.With(zap.String("db", job.dbName), zap.String("table", job.tblName))
-			sqlWithRetry := common.SQLWithRetry{
-				Logger: worker.logger,
-				DB:     session,
-			}
-			for _, stmt := range job.stmts {
-				task := logger.Begin(zap.DebugLevel, fmt.Sprintf("execute SQL: %s", stmt))
-				err = sqlWithRetry.Exec(worker.ctx, "run create schema job", stmt)
-				if err != nil {
-					// try to imitate IF NOT EXISTS behavior for parsing errors
-					exists := false
-					switch job.stmtType {
-					case schemaCreateDatabase:
-						var err2 error
-						exists, err2 = common.SchemaExists(worker.ctx, session, job.dbName)
-						if err2 != nil {
-							task.Error("failed to check database existence", zap.Error(err2))
-						}
-					case schemaCreateTable:
-						exists, _ = common.TableExists(worker.ctx, session, job.dbName, job.tblName)
-					}
-					if exists {
-						err = nil
-					}
-				}
-				task.End(zap.ErrorLevel, err)
-
-				if err != nil {
-					err = common.ErrCreateSchema.Wrap(err).GenWithStackByArgs(common.UniqueTable(job.dbName, job.tblName), job.stmtType.String())
-					worker.wg.Done()
-					worker.throw(err)
-					// don't return
-					break loop
-				}
-			}
-			worker.wg.Done()
-		}
-	}
-	// mark the cancelled job as `Done`, a little tricky,
-	// cauz we need make sure `worker.wg.Wait()` wouldn't blocked forever
-	for range worker.jobCh {
-		worker.wg.Done()
-	}
-}
-
-func (worker *restoreSchemaWorker) wait() error {
-	// avoid to `worker.wg.Wait()` blocked forever when all `doJob`'s goroutine exited.
-	// don't worry about goroutine below, it never become a zombie,
-	// cauz we have mechanism to clean cancelled jobs in `worker.jobCh`.
-	// means whole jobs has been send to `worker.jobCh` would be done.
-	waitCh := make(chan struct{})
-	go func() {
-		worker.wg.Wait()
-		close(waitCh)
-	}()
-	select {
-	case err := <-worker.errCh:
-		return err
-	case <-worker.ctx.Done():
-		return worker.ctx.Err()
-	case <-waitCh:
-		return nil
-	}
-}
-
-func (worker *restoreSchemaWorker) throw(err error) {
-	select {
-	case <-worker.ctx.Done():
-		// don't throw `worker.ctx.Err()` again, it will be blocked to death.
-		return
-	case worker.errCh <- err:
-		worker.quit()
-	}
-}
-
-func (worker *restoreSchemaWorker) appendJob(job *schemaJob) error {
-	worker.wg.Add(1)
-	select {
-	case err := <-worker.errCh:
-		// cancel the job
-		worker.wg.Done()
-		return err
-	case <-worker.ctx.Done():
-		// cancel the job
-		worker.wg.Done()
-		return errors.Trace(worker.ctx.Err())
-	case worker.jobCh <- job:
-		return nil
-	}
-}
-
 func (rc *Controller) restoreSchema(ctx context.Context) error {
 	// create table with schema file
 	// we can handle the duplicated created with createIfNotExist statement
 	// and we will check the schema in TiDB is valid with the datafile in DataCheck later.
-	logTask := log.FromContext(ctx).Begin(zap.InfoLevel, "restore all schema")
+	logger := log.FromContext(ctx)
 	concurrency := min(rc.cfg.App.RegionConcurrency, 8)
-	childCtx, cancel := context.WithCancel(ctx)
-	p := parser.New()
-	p.SetSQLMode(rc.cfg.TiDB.SQLMode)
-	worker := restoreSchemaWorker{
-		ctx:    childCtx,
-		quit:   cancel,
-		logger: log.FromContext(ctx),
-		jobCh:  make(chan *schemaJob, concurrency),
-		errCh:  make(chan error),
-		db:     rc.db,
-		parser: p,
-		store:  rc.store,
-	}
-	for i := 0; i < concurrency; i++ {
-		go worker.doJob()
-	}
-	err := worker.makeJobs(rc.dbMetas, rc.preInfoGetter.FetchRemoteDBModels, rc.preInfoGetter.FetchRemoteTableModels)
-	logTask.End(zap.ErrorLevel, err)
+	// sql.DB is a connection pool; we set its max idle count to concurrency + 1
+	// (one extra for the job generator) so connections are reused across the
+	// many db.Conn/conn.Close calls below.
+	// There's no API to read sql.DB's current MaxIdleConns, so afterwards we
+	// restore its documented default, which is 2.
+	rc.db.SetMaxIdleConns(concurrency + 1)
+	defer rc.db.SetMaxIdleConns(2)
+	schemaImp := mydump.NewSchemaImporter(logger, rc.cfg.TiDB.SQLMode, rc.db, rc.store, concurrency)
+	err := schemaImp.Run(ctx, rc.dbMetas)
 	if err != nil {
 		return err
 	}
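The restoreSchema rewrite leans on database/sql pool sizing instead of the hand-rolled worker pool it deletes: widening MaxIdleConns for the duration of the schema-restore burst lets the many short-lived db.Conn/conn.Close pairs reuse connections. A small sketch of that pattern under the same assumption the comment states (the stdlib exposes no getter, so the code restores Go's documented default of 2 afterwards); withIdleConns is a hypothetical helper, not a lightning API:

	package poolsize

	import (
		"context"
		"database/sql"
	)

	// withIdleConns widens the idle-connection pool around a burst of
	// concurrent, short-lived connection use, then restores Go's default.
	func withIdleConns(ctx context.Context, db *sql.DB, n int, f func(context.Context) error) error {
		db.SetMaxIdleConns(n)
		defer db.SetMaxIdleConns(2) // database/sql's documented default
		return f(ctx)
	}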
diff --git a/lightning/pkg/importer/restore_schema_test.go b/lightning/pkg/importer/restore_schema_test.go
deleted file mode 100644
index d8a4026cb9eff..0000000000000
--- a/lightning/pkg/importer/restore_schema_test.go
+++ /dev/null
@@ -1,247 +0,0 @@
-// Copyright 2022 PingCAP, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package importer
-
-import (
-	"context"
-	stderrors "errors"
-	"fmt"
-	"testing"
-
-	"github.com/DATA-DOG/go-sqlmock"
-	"github.com/pingcap/errors"
-	"github.com/pingcap/tidb/br/pkg/mock"
-	"github.com/pingcap/tidb/br/pkg/storage"
-	"github.com/pingcap/tidb/pkg/ddl"
-	"github.com/pingcap/tidb/pkg/lightning/backend"
-	"github.com/pingcap/tidb/pkg/lightning/checkpoints"
-	"github.com/pingcap/tidb/pkg/lightning/config"
-	"github.com/pingcap/tidb/pkg/lightning/mydump"
-	"github.com/pingcap/tidb/pkg/parser"
-	"github.com/pingcap/tidb/pkg/parser/ast"
-	"github.com/pingcap/tidb/pkg/parser/model"
-	"github.com/pingcap/tidb/pkg/parser/mysql"
-	tmock "github.com/pingcap/tidb/pkg/util/mock"
-	filter "github.com/pingcap/tidb/pkg/util/table-filter"
-	"github.com/stretchr/testify/require"
-	"github.com/stretchr/testify/suite"
-	"go.uber.org/mock/gomock"
-)
-
-type restoreSchemaSuite struct {
-	suite.Suite
-	ctx              context.Context
-	rc               *Controller
-	controller       *gomock.Controller
-	dbMock           sqlmock.Sqlmock
-	tableInfos       []*model.TableInfo
-	infoGetter       *PreImportInfoGetterImpl
-	targetInfoGetter *TargetInfoGetterImpl
-}
-
-func TestRestoreSchemaSuite(t *testing.T) {
-	suite.Run(t, new(restoreSchemaSuite))
-}
-
-func (s *restoreSchemaSuite) SetupSuite() {
-	ctx := context.Background()
-	fakeDataDir := s.T().TempDir()
-
-	store, err := storage.NewLocalStorage(fakeDataDir)
-	require.NoError(s.T(), err)
-	// restore database schema file
-	fakeDBName := "fakedb"
-	// please follow the `mydump.defaultFileRouteRules`, matches files like '{schema}-schema-create.sql'
-	fakeFileName := fmt.Sprintf("%s-schema-create.sql", fakeDBName)
-	err = store.WriteFile(ctx, fakeFileName, []byte(fmt.Sprintf("CREATE DATABASE %s;", fakeDBName)))
-	require.NoError(s.T(), err)
-	// restore table schema files
-	fakeTableFilesCount := 8
-
-	p := parser.New()
-	p.SetSQLMode(mysql.ModeANSIQuotes)
-	se := tmock.NewContext()
-
-	tableInfos := make([]*model.TableInfo, 0, fakeTableFilesCount)
-	for i := 1; i <= fakeTableFilesCount; i++ {
-		fakeTableName := fmt.Sprintf("tbl%d", i)
-		// please follow the `mydump.defaultFileRouteRules`, matches files like '{schema}.{table}-schema.sql'
-		fakeFileName := fmt.Sprintf("%s.%s-schema.sql", fakeDBName, fakeTableName)
-		fakeFileContent := fmt.Sprintf("CREATE TABLE %s(i TINYINT);", fakeTableName)
-		err = store.WriteFile(ctx, fakeFileName, []byte(fakeFileContent))
-		require.NoError(s.T(), err)
-
-		node, err := p.ParseOneStmt(fakeFileContent, "", "")
-		require.NoError(s.T(), err)
-		core, err := ddl.MockTableInfo(se, node.(*ast.CreateTableStmt), 0xabcdef)
-		require.NoError(s.T(), err)
-		core.State = model.StatePublic
-		tableInfos = append(tableInfos, core)
-	}
-	s.tableInfos = tableInfos
-	// restore view schema files
-	fakeViewFilesCount := 8
-	for i := 1; i <= fakeViewFilesCount; i++ {
-		fakeViewName := fmt.Sprintf("tbl%d", i)
-		// please follow the `mydump.defaultFileRouteRules`, matches files like '{schema}.{table}-schema-view.sql'
-		fakeFileName := fmt.Sprintf("%s.%s-schema-view.sql", fakeDBName, fakeViewName)
-		fakeFileContent := []byte(fmt.Sprintf("CREATE ALGORITHM=UNDEFINED VIEW `%s` (`i`) AS SELECT `i` FROM `%s`.`%s`;", fakeViewName, fakeDBName, fmt.Sprintf("tbl%d", i)))
-		err = store.WriteFile(ctx, fakeFileName, fakeFileContent)
-		require.NoError(s.T(), err)
-	}
-	config := config.NewConfig()
-	config.Mydumper.DefaultFileRules = true
-	config.Mydumper.CharacterSet = "utf8mb4"
-	config.App.RegionConcurrency = 8
-	mydumpLoader, err := mydump.NewLoaderWithStore(ctx, mydump.NewLoaderCfg(config), store)
-	s.Require().NoError(err)
-
-	dbMetas := mydumpLoader.GetDatabases()
-	targetInfoGetter := &TargetInfoGetterImpl{
-		cfg: config,
-	}
-	preInfoGetter := &PreImportInfoGetterImpl{
-		cfg:              config,
-		srcStorage:       store,
-		targetInfoGetter: targetInfoGetter,
-		dbMetas:          dbMetas,
-	}
-	preInfoGetter.Init()
-	s.rc = &Controller{
-		checkTemplate: NewSimpleTemplate(),
-		cfg:           config,
-		store:         store,
-		dbMetas:       dbMetas,
-		checkpointsDB: &checkpoints.NullCheckpointsDB{},
-		preInfoGetter: preInfoGetter,
-	}
-	s.infoGetter = preInfoGetter
-	s.targetInfoGetter = targetInfoGetter
-}
-
-//nolint:interfacer // change test case signature might cause Check failed to find this test case?
-func (s *restoreSchemaSuite) SetupTest() {
-	s.controller, s.ctx = gomock.WithContext(context.Background(), s.T())
-	mockTargetInfoGetter := mock.NewMockTargetInfoGetter(s.controller)
-	mockBackend := mock.NewMockBackend(s.controller)
-	mockTargetInfoGetter.EXPECT().
-		FetchRemoteDBModels(gomock.Any()).
-		AnyTimes().
-		Return([]*model.DBInfo{{Name: model.NewCIStr("fakedb")}}, nil)
-	mockTargetInfoGetter.EXPECT().
-		FetchRemoteTableModels(gomock.Any(), gomock.Any()).
-		AnyTimes().
-		Return(s.tableInfos, nil)
-	mockBackend.EXPECT().Close()
-	theBackend := backend.MakeEngineManager(mockBackend)
-	s.rc.engineMgr = theBackend
-	s.rc.backend = mockBackend
-	s.targetInfoGetter.backend = mockTargetInfoGetter
-
-	mockDB, sqlMock, err := sqlmock.New()
-	require.NoError(s.T(), err)
-	for i := 0; i < 17; i++ {
-		sqlMock.ExpectExec(".*").WillReturnResult(sqlmock.NewResult(int64(i), 1))
-	}
-	s.targetInfoGetter.db = mockDB
-	s.rc.db = mockDB
-	s.dbMock = sqlMock
-}
-
-func (s *restoreSchemaSuite) TearDownTest() {
-	s.rc.Close()
-	s.controller.Finish()
-}
-
-func (s *restoreSchemaSuite) TestRestoreSchemaSuccessful() {
-	// before restore, if sysVars is initialized by other test, the time_zone should be default value
-	if len(s.rc.sysVars) > 0 {
-		tz, ok := s.rc.sysVars["time_zone"]
-		require.True(s.T(), ok)
-		require.Equal(s.T(), "SYSTEM", tz)
-	}
-
-	s.dbMock.ExpectQuery(".*").WillReturnRows(sqlmock.NewRows([]string{"time_zone"}).AddRow("SYSTEM"))
-	s.rc.cfg.TiDB.Vars = map[string]string{
-		"time_zone": "UTC",
-	}
-	err := s.rc.restoreSchema(s.ctx)
-	require.NoError(s.T(), err)
-
-	// test after restore schema, sysVars has been updated
-	tz, ok := s.rc.sysVars["time_zone"]
-	require.True(s.T(), ok)
-	require.Equal(s.T(), "UTC", tz)
-}
-
-func (s *restoreSchemaSuite) TestRestoreSchemaFailed() {
-	// use injectErr which cannot be retried
-	injectErr := stderrors.New("could not match actual sql")
-	mockDB, sqlMock, err := sqlmock.New()
-	require.NoError(s.T(), err)
-	sqlMock.ExpectExec(".*").WillReturnError(injectErr)
-	for i := 0; i < 16; i++ {
-		sqlMock.ExpectExec(".*").WillReturnResult(sqlmock.NewResult(int64(i), 1))
-	}
-
-	s.rc.db = mockDB
-	s.targetInfoGetter.db = mockDB
-	err = s.rc.restoreSchema(s.ctx)
-	require.Error(s.T(), err)
-	require.True(s.T(), errors.ErrorEqual(err, injectErr))
-}
-
-// When restoring a CSV with `-no-schema` and the target table doesn't exist
-// then we can't restore the schema as the `Path` is empty. This is to make
-// sure this results in the correct error.
-// https://github.com/pingcap/br/issues/1394
-func (s *restoreSchemaSuite) TestNoSchemaPath() {
-	fakeTable := mydump.MDTableMeta{
-		DB:   "fakedb",
-		Name: "fake1",
-		SchemaFile: mydump.FileInfo{
-			TableName: filter.Table{
-				Schema: "fakedb",
-				Name:   "fake1",
-			},
-			FileMeta: mydump.SourceFileMeta{
-				Path: "",
-			},
-		},
-		DataFiles: []mydump.FileInfo{},
-		TotalSize: 0,
-	}
-	s.rc.dbMetas[0].Tables = append(s.rc.dbMetas[0].Tables, &fakeTable)
-	err := s.rc.restoreSchema(s.ctx)
-	require.Error(s.T(), err)
-	require.Regexp(s.T(), `table .* schema not found`, err.Error())
-	s.rc.dbMetas[0].Tables = s.rc.dbMetas[0].Tables[:len(s.rc.dbMetas[0].Tables)-1]
-}
-
-func (s *restoreSchemaSuite) TestRestoreSchemaContextCancel() {
-	childCtx, cancel := context.WithCancel(s.ctx)
-	mockDB, sqlMock, err := sqlmock.New()
-	require.NoError(s.T(), err)
-	for i := 0; i < 17; i++ {
-		sqlMock.ExpectExec(".*").WillReturnResult(sqlmock.NewResult(int64(i), 1))
-	}
-	s.rc.db = mockDB
-	s.targetInfoGetter.db = mockDB
-	cancel()
-	err = s.rc.restoreSchema(childCtx)
-	require.Error(s.T(), err)
-	err = errors.Cause(err)
-	require.Equal(s.T(), childCtx.Err(), err)
-}
diff --git a/lightning/pkg/importer/tidb.go b/lightning/pkg/importer/tidb.go
index 9f083f76feb5c..a732bee093b5f 100644
--- a/lightning/pkg/importer/tidb.go
+++ b/lightning/pkg/importer/tidb.go
@@ -31,8 +31,6 @@ import (
 	"github.com/pingcap/tidb/pkg/lightning/metric"
 	"github.com/pingcap/tidb/pkg/lightning/mydump"
 	"github.com/pingcap/tidb/pkg/parser"
-	"github.com/pingcap/tidb/pkg/parser/ast"
-	"github.com/pingcap/tidb/pkg/parser/format"
 	"github.com/pingcap/tidb/pkg/parser/model"
 	"github.com/pingcap/tidb/pkg/parser/mysql"
 	"github.com/pingcap/tidb/pkg/sessionctx/variable"
@@ -131,47 +129,6 @@ func (timgr *TiDBManager) Close() {
 	timgr.db.Close()
 }
 
-func createIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName string) ([]string, error) {
-	stmts, _, err := p.ParseSQL(createTable)
-	if err != nil {
-		return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable)
-	}
-
-	var res strings.Builder
-	ctx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreTiDBSpecialComment|format.RestoreWithTTLEnableOff, &res)
-
-	retStmts := make([]string, 0, len(stmts))
-	for _, stmt := range stmts {
-		switch node := stmt.(type) {
-		case *ast.CreateDatabaseStmt:
-			node.Name = model.NewCIStr(dbName)
-			node.IfNotExists = true
-		case *ast.DropDatabaseStmt:
-			node.Name = model.NewCIStr(dbName)
-			node.IfExists = true
-		case *ast.CreateTableStmt:
-			node.Table.Schema = model.NewCIStr(dbName)
-			node.Table.Name = model.NewCIStr(tblName)
-			node.IfNotExists = true
-		case *ast.CreateViewStmt:
-			node.ViewName.Schema = model.NewCIStr(dbName)
-			node.ViewName.Name = model.NewCIStr(tblName)
-		case *ast.DropTableStmt:
-			node.Tables[0].Schema = model.NewCIStr(dbName)
-			node.Tables[0].Name = model.NewCIStr(tblName)
-			node.IfExists = true
-		}
-		if err := stmt.Restore(ctx); err != nil {
-			return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable)
-		}
-		ctx.WritePlain(";")
-		retStmts = append(retStmts, res.String())
-		res.Reset()
-	}
-
-	return retStmts, nil
-}
-
 // DropTable drops a table.
 func (timgr *TiDBManager) DropTable(ctx context.Context, tableName string) error {
 	sql := common.SQLWithRetry{
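The deleted createIfNotExistsStmt shows the parser-based rewrite that the new mydump.SchemaImporter now performs internally: parse the dumped DDL, force the target schema and table name, add IF NOT EXISTS, and restore the SQL text. A trimmed sketch of that technique with the TiDB parser (the same packages the deleted code imports; runnable assuming the TiDB parser module is available, with error handling reduced to panics for brevity):

	package main

	import (
		"fmt"
		"strings"

		"github.com/pingcap/tidb/pkg/parser"
		"github.com/pingcap/tidb/pkg/parser/ast"
		"github.com/pingcap/tidb/pkg/parser/format"
		"github.com/pingcap/tidb/pkg/parser/model"
	)

	// createIfNotExists rewrites a dumped CREATE TABLE so it targets db.tbl
	// and tolerates the table already existing downstream.
	func createIfNotExists(sql, db, tbl string) string {
		p := parser.New()
		stmts, _, err := p.ParseSQL(sql)
		if err != nil {
			panic(err)
		}
		var sb strings.Builder
		ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)
		if node, ok := stmts[0].(*ast.CreateTableStmt); ok {
			node.Table.Schema = model.NewCIStr(db)
			node.Table.Name = model.NewCIStr(tbl)
			node.IfNotExists = true
		}
		if err := stmts[0].Restore(ctx); err != nil {
			panic(err)
		}
		return sb.String() + ";"
	}

	func main() {
		fmt.Println(createIfNotExists("CREATE TABLE `foo`(`bar` TINYINT(1));", "testdb", "foo"))
		// CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));
	}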
diff --git a/lightning/pkg/importer/tidb_test.go b/lightning/pkg/importer/tidb_test.go
index 830ab7b24218f..e10f3bede6c3b 100644
--- a/lightning/pkg/importer/tidb_test.go
+++ b/lightning/pkg/importer/tidb_test.go
@@ -28,7 +28,6 @@ import (
 	"github.com/pingcap/tidb/pkg/lightning/checkpoints"
 	"github.com/pingcap/tidb/pkg/lightning/metric"
 	"github.com/pingcap/tidb/pkg/lightning/mydump"
-	"github.com/pingcap/tidb/pkg/parser"
 	"github.com/pingcap/tidb/pkg/parser/ast"
 	"github.com/pingcap/tidb/pkg/parser/model"
 	tmysql "github.com/pingcap/tidb/pkg/parser/mysql"
@@ -61,109 +60,6 @@ func newTiDBSuite(t *testing.T) *tidbSuite {
 	return &s
 }
 
-func TestCreateTableIfNotExistsStmt(t *testing.T) {
-	dbName := "testdb"
-	p := parser.New()
-	createSQLIfNotExistsStmt := func(createTable, tableName string) []string {
-		res, err := createIfNotExistsStmt(p, createTable, dbName, tableName)
-		require.NoError(t, err)
-		return res
-	}
-
-	require.Equal(t, []string{"CREATE DATABASE IF NOT EXISTS `testdb` CHARACTER SET = utf8 COLLATE = utf8_general_ci;"},
-		createSQLIfNotExistsStmt("CREATE DATABASE `foo` CHARACTER SET = utf8 COLLATE = utf8_general_ci;", ""))
-
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"},
-		createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` TINYINT(1));", "foo"))
-
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"},
-		createSQLIfNotExistsStmt("CREATE TABLE IF NOT EXISTS `foo`(`bar` TINYINT(1));", "foo"))
-
-	// case insensitive
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`fOo` (`bar` TINYINT(1));"},
-		createSQLIfNotExistsStmt("/* cOmmEnt */ creAte tablE `fOo`(`bar` TinyinT(1));", "fOo"))
-
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`FoO` (`bAR` TINYINT(1));"},
-		createSQLIfNotExistsStmt("/* coMMenT */ crEatE tAble If not EXISts `FoO`(`bAR` tiNyInT(1));", "FoO"))
-
-	// only one "CREATE TABLE" is replaced
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE');"},
-		createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE');", "foo"))
-
-	// test clustered index consistency
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] CLUSTERED */ COMMENT 'CREATE TABLE');"},
-		createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY CLUSTERED COMMENT 'CREATE TABLE');", "foo"))
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] NONCLUSTERED */);"},
-		createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) NONCLUSTERED);", "foo"))
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');"},
-		createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');", "foo"))
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] CLUSTERED */);"},
-		createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) /*T![clustered_index] CLUSTERED */);", "foo"))
-
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![auto_rand] AUTO_RANDOM(2) */ COMMENT 'CREATE TABLE');"},
-		createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY AUTO_RANDOM(2) COMMENT 'CREATE TABLE');", "foo"))
-
-	// upper case becomes shorter
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ſ` (`ı` TINYINT(1));"},
-		createSQLIfNotExistsStmt("CREATE TABLE `ſ`(`ı` TINYINT(1));", "ſ"))
-
-	// upper case becomes longer
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ɑ` (`ȿ` TINYINT(1));"},
-		createSQLIfNotExistsStmt("CREATE TABLE `ɑ`(`ȿ` TINYINT(1));", "ɑ"))
-
-	// non-utf-8
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`\xcc\xcc\xcc` (`???` TINYINT(1));"},
-		createSQLIfNotExistsStmt("CREATE TABLE `\xcc\xcc\xcc`(`\xdd\xdd\xdd` TINYINT(1));", "\xcc\xcc\xcc"))
-
-	// renaming a table
-	require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ba``r` (`x` INT);"},
-		createSQLIfNotExistsStmt("create table foo(x int);", "ba`r"))
-
-	// conditional comments
-	require.Equal(t, []string{
-		"SET NAMES 'binary';",
-		"SET @@SESSION.`FOREIGN_KEY_CHECKS`=0;",
-		"CREATE TABLE IF NOT EXISTS `testdb`.`m` (`z` DOUBLE) ENGINE = InnoDB AUTO_INCREMENT = 8343230 DEFAULT CHARACTER SET = UTF8;",
-	},
-		createSQLIfNotExistsStmt(`
-		/*!40101 SET NAMES binary*/;
-		/*!40014 SET FOREIGN_KEY_CHECKS=0*/;
-		CREATE TABLE x.y (z double) ENGINE=InnoDB AUTO_INCREMENT=8343230 DEFAULT CHARSET=utf8;
-		`, "m"))
-
-	// create view
-	require.Equal(t, []string{
-		"SET NAMES 'binary';",
-		"DROP TABLE IF EXISTS `testdb`.`m`;",
-		"DROP VIEW IF EXISTS `testdb`.`m`;",
-		"SET @`PREV_CHARACTER_SET_CLIENT`=@@`character_set_client`;",
-		"SET @`PREV_CHARACTER_SET_RESULTS`=@@`character_set_results`;",
-		"SET @`PREV_COLLATION_CONNECTION`=@@`collation_connection`;",
-		"SET @@SESSION.`character_set_client`=`utf8`;",
-		"SET @@SESSION.`character_set_results`=`utf8`;",
-		"SET @@SESSION.`collation_connection`=`utf8_general_ci`;",
-		"CREATE ALGORITHM = UNDEFINED DEFINER = `root`@`192.168.198.178` SQL SECURITY DEFINER VIEW `testdb`.`m` (`s`) AS SELECT `s` FROM `db1`.`v1` WHERE `i`<2;",
-		"SET @@SESSION.`character_set_client`=@`PREV_CHARACTER_SET_CLIENT`;",
-		"SET @@SESSION.`character_set_results`=@`PREV_CHARACTER_SET_RESULTS`;",
-		"SET @@SESSION.`collation_connection`=@`PREV_COLLATION_CONNECTION`;",
-	},
-		createSQLIfNotExistsStmt(`
-		/*!40101 SET NAMES binary*/;
-		DROP TABLE IF EXISTS v2;
-		DROP VIEW IF EXISTS v2;
-		SET @PREV_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT;
-		SET @PREV_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS;
-		SET @PREV_COLLATION_CONNECTION=@@COLLATION_CONNECTION;
-		SET character_set_client = utf8;
-		SET character_set_results = utf8;
-		SET collation_connection = utf8_general_ci;
-		CREATE ALGORITHM=UNDEFINED DEFINER=root@192.168.198.178 SQL SECURITY DEFINER VIEW v2 (s) AS SELECT s FROM db1.v1 WHERE i<2;
-		SET character_set_client = @PREV_CHARACTER_SET_CLIENT;
-		SET character_set_results = @PREV_CHARACTER_SET_RESULTS;
-		SET collation_connection = @PREV_COLLATION_CONNECTION;
-		`, "m"))
-}
-
 func TestDropTable(t *testing.T) {
 	s := newTiDBSuite(t)
 	ctx := context.Background()
diff --git a/pkg/ddl/BUILD.bazel b/pkg/ddl/BUILD.bazel
index cae7a382ff8c4..1f504132bb680 100644
--- a/pkg/ddl/BUILD.bazel
+++ b/pkg/ddl/BUILD.bazel
@@ -53,6 +53,7 @@ go_library(
         "rollingback.go",
         "sanity_check.go",
         "schema.go",
+        "schema_version.go",
         "sequence.go",
         "split_region.go",
         "stat.go",
diff --git a/pkg/ddl/callback.go b/pkg/ddl/callback.go
index 114003317df4c..8b89ed42a7776 100644
--- a/pkg/ddl/callback.go
+++ b/pkg/ddl/callback.go
@@ -49,6 +49,7 @@ type Callback interface {
 	// OnChanged is called after a ddl statement is finished.
 	OnChanged(err error) error
 	// OnSchemaStateChanged is called after a schema state is changed.
+	// Only called inside tests.
 	OnSchemaStateChanged(schemaVer int64)
 	// OnJobRunBefore is called before running job.
 	OnJobRunBefore(job *model.Job)
diff --git a/pkg/ddl/ddl.go b/pkg/ddl/ddl.go
index 03bde23fd274c..48c99cfae5f90 100644
--- a/pkg/ddl/ddl.go
+++ b/pkg/ddl/ddl.go
@@ -368,7 +368,7 @@ type ddlCtx struct {
 	stateSyncer  syncer.StateSyncer
 	ddlJobDoneCh chan struct{}
 	ddlEventCh   chan<- *statsutil.DDLEvent
-	lease        time.Duration // lease is schema lease.
+	lease        time.Duration // lease is schema lease, default 45s, see config.Lease.
 	binlogCli    *pumpcli.PumpsClient // binlogCli is used for Binlog.
 	infoCache    *infoschema.InfoCache
 	statsHandle  *handle.Handle
@@ -398,6 +398,7 @@ type ddlCtx struct {
 	// hook may be modified.
 	mu struct {
 		sync.RWMutex
+		// See newDefaultCallBack for its value in the normal flow.
 		hook        Callback
 		interceptor Interceptor
 	}
diff --git a/pkg/ddl/ddl_worker.go b/pkg/ddl/ddl_worker.go
index c205b7a98df76..898da94134877 100644
--- a/pkg/ddl/ddl_worker.go
+++ b/pkg/ddl/ddl_worker.go
@@ -43,7 +43,6 @@ import (
 	tidbutil "github.com/pingcap/tidb/pkg/util"
 	"github.com/pingcap/tidb/pkg/util/dbterror"
 	"github.com/pingcap/tidb/pkg/util/logutil"
-	"github.com/pingcap/tidb/pkg/util/mathutil"
 	"github.com/pingcap/tidb/pkg/util/resourcegrouptag"
 	"github.com/pingcap/tidb/pkg/util/topsql"
 	topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state"
@@ -103,7 +102,7 @@ type worker struct {
 	sess            *sess.Session // sess is used and only used in running DDL job.
 	delRangeManager delRangeManager
 	logCtx          context.Context
-	lockSeqNum      bool
+	seqNumLocked    bool
 
 	*ddlCtx
 }
@@ -214,6 +213,7 @@ func (d *ddl) limitDDLJobs(ch chan *limitJobTask, handler func(tasks []*limitJob
 	tasks := make([]*limitJobTask, 0, batchAddingJobs)
 	for {
 		select {
+		// the channel is never closed
 		case task := <-ch:
 			tasks = tasks[:0]
 			jobLen := len(ch)
@@ -394,11 +394,11 @@ func (d *ddl) addBatchDDLJobs(tasks []*limitJobTask) error {
 		return errors.Trace(err)
 	}
 	defer d.sessPool.Put(se)
-	job, err := getJobsBySQL(sess.NewSession(se), JobTable, fmt.Sprintf("type = %d", model.ActionFlashbackCluster))
+	jobs, err := getJobsBySQL(sess.NewSession(se), JobTable, fmt.Sprintf("type = %d", model.ActionFlashbackCluster))
 	if err != nil {
 		return errors.Trace(err)
 	}
-	if len(job) != 0 {
+	if len(jobs) != 0 {
 		return errors.Errorf("Can't add ddl job, have flashback cluster job")
 	}
 
@@ -745,7 +745,7 @@ func (w *worker) finishDDLJob(t *meta.Meta, job *model.Job) (err error) {
 func (w *worker) writeDDLSeqNum(job *model.Job) {
 	w.ddlSeqNumMu.Lock()
 	w.ddlSeqNumMu.seqNum++
-	w.lockSeqNum = true
+	w.seqNumLocked = true
 	job.SeqNum = w.ddlSeqNumMu.seqNum
 }
 
@@ -800,12 +800,12 @@ func (w *JobContext) setDDLLabelForTopSQL(jobQuery string) {
 }
 
 func (w *worker) unlockSeqNum(err error) {
-	if w.lockSeqNum {
+	if w.seqNumLocked {
 		if err != nil {
 			// if meet error, we should reset seqNum.
 			w.ddlSeqNumMu.seqNum--
 		}
-		w.lockSeqNum = false
+		w.seqNumLocked = false
 		w.ddlSeqNumMu.Unlock()
 	}
 }
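The lockSeqNum → seqNumLocked rename above makes the field read as what it is: a flag recording that this worker still holds ddlSeqNumMu, so unlockSeqNum can run on every exit path while only the actual owner releases the mutex and rolls back the reserved sequence number on error. A simplified, self-contained sketch of that hold-flag pattern (types are stand-ins, not the ddl package's):

	package main

	import (
		"fmt"
		"sync"
	)

	type worker struct {
		mu        sync.Mutex
		seqNum    uint64
		seqLocked bool // true while this worker holds mu
	}

	// writeSeqNum reserves the next sequence number and keeps mu held,
	// as writeDDLSeqNum does until the job is finished.
	func (w *worker) writeSeqNum() uint64 {
		w.mu.Lock()
		w.seqLocked = true
		w.seqNum++
		return w.seqNum
	}

	// unlockSeqNum is safe to call unconditionally; it only acts when the
	// flag shows the mutex is actually held.
	func (w *worker) unlockSeqNum(err error) {
		if !w.seqLocked {
			return
		}
		if err != nil {
			w.seqNum-- // give the reserved number back on failure
		}
		w.seqLocked = false
		w.mu.Unlock()
	}

	func main() {
		w := &worker{}
		n := w.writeSeqNum()
		w.unlockSeqNum(fmt.Errorf("job failed"))
		fmt.Println(n, w.seqNum) // 1 0
	}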
@@ -1402,68 +1402,12 @@ func waitSchemaChanged(d *ddlCtx, waitTime time.Duration, latestSchemaVersion in
 	return checkAllVersions(d, job, latestSchemaVersion, timeStart)
 }
 
-func checkAllVersions(d *ddlCtx, job *model.Job, latestSchemaVersion int64, timeStart time.Time) error {
-	failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) {
-		if val.(bool) {
-			if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion {
-				panic("check down before update global version failed")
-			}
-			mockDDLErrOnce = -1
-		}
-	})
-
-	// OwnerCheckAllVersions returns only when all TiDB schemas are synced(exclude the isolated TiDB).
-	err := d.schemaSyncer.OwnerCheckAllVersions(d.ctx, job.ID, latestSchemaVersion)
-	if err != nil {
-		logutil.Logger(d.ctx).Info("wait latest schema version encounter error", zap.String("category", "ddl"), zap.Int64("ver", latestSchemaVersion),
-			zap.Int64("jobID", job.ID), zap.Duration("take time", time.Since(timeStart)), zap.Error(err))
-		return err
-	}
-	logutil.Logger(d.ctx).Info("wait latest schema version changed(get the metadata lock if tidb_enable_metadata_lock is true)", zap.String("category", "ddl"),
-		zap.Int64("ver", latestSchemaVersion),
-		zap.Duration("take time", time.Since(timeStart)),
-		zap.String("job", job.String()))
-	return nil
-}
-
 // waitSchemaSyncedForMDL likes waitSchemaSynced, but it waits for getting the metadata lock of the latest version of this DDL.
 func waitSchemaSyncedForMDL(d *ddlCtx, job *model.Job, latestSchemaVersion int64) error {
 	timeStart := time.Now()
 	return checkAllVersions(d, job, latestSchemaVersion, timeStart)
 }
 
-// waitSchemaSynced handles the following situation:
-// If the job enters a new state, and the worker crashs when it's in the process of waiting for 2 * lease time,
-// Then the worker restarts quickly, we may run the job immediately again,
-// but in this case we don't wait enough 2 * lease time to let other servers update the schema.
-// So here we get the latest schema version to make sure all servers' schema version update to the latest schema version
-// in a cluster, or to wait for 2 * lease time.
-func waitSchemaSynced(d *ddlCtx, job *model.Job, waitTime time.Duration) error {
-	if !job.IsRunning() && !job.IsRollingback() && !job.IsDone() && !job.IsRollbackDone() {
-		return nil
-	}
-
-	ver, _ := d.store.CurrentVersion(kv.GlobalTxnScope)
-	snapshot := d.store.GetSnapshot(ver)
-	m := meta.NewSnapshotMeta(snapshot)
-	latestSchemaVersion, err := m.GetSchemaVersionWithNonEmptyDiff()
-	if err != nil {
-		logutil.Logger(d.ctx).Warn("get global version failed", zap.String("category", "ddl"), zap.Int64("jobID", job.ID), zap.Error(err))
-		return err
-	}
-
-	failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) {
-		if val.(bool) {
-			if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion {
-				panic("check down before update global version failed")
-			}
-			mockDDLErrOnce = -1
-		}
-	})
-
-	return waitSchemaChanged(d, waitTime, latestSchemaVersion, job)
-}
-
 func buildPlacementAffects(oldIDs []int64, newIDs []int64) []*model.AffectedOption {
 	if len(oldIDs) == 0 {
 		return nil
@@ -1478,254 +1422,3 @@ func buildPlacementAffects(oldIDs []int64, newIDs []int64) []*model.AffectedOpti
 	}
 	return affects
 }
-
-// updateSchemaVersion increments the schema version by 1 and sets SchemaDiff.
-func updateSchemaVersion(d *ddlCtx, t *meta.Meta, job *model.Job, multiInfos ...schemaIDAndTableInfo) (int64, error) {
-	schemaVersion, err := d.setSchemaVersion(job, d.store)
-	if err != nil {
-		return 0, errors.Trace(err)
-	}
-	diff := &model.SchemaDiff{
-		Version:  schemaVersion,
-		Type:     job.Type,
-		SchemaID: job.SchemaID,
-	}
-	switch job.Type {
-	case model.ActionCreateTables:
-		var tableInfos []*model.TableInfo
-		err = job.DecodeArgs(&tableInfos)
-		if err != nil {
-			return 0, errors.Trace(err)
-		}
-		diff.AffectedOpts = make([]*model.AffectedOption, len(tableInfos))
-		for i := range tableInfos {
-			diff.AffectedOpts[i] = &model.AffectedOption{
-				SchemaID:    job.SchemaID,
-				OldSchemaID: job.SchemaID,
-				TableID:     tableInfos[i].ID,
-				OldTableID:  tableInfos[i].ID,
-			}
-		}
-	case model.ActionTruncateTable:
-		// Truncate table has two table ID, should be handled differently.
-		err = job.DecodeArgs(&diff.TableID)
-		if err != nil {
-			return 0, errors.Trace(err)
-		}
-		diff.OldTableID = job.TableID
-
-		// affects are used to update placement rule cache
-		if len(job.CtxVars) > 0 {
-			oldIDs := job.CtxVars[0].([]int64)
-			newIDs := job.CtxVars[1].([]int64)
-			diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs)
-		}
-	case model.ActionCreateView:
-		tbInfo := &model.TableInfo{}
-		var orReplace bool
-		var oldTbInfoID int64
-		if err := job.DecodeArgs(tbInfo, &orReplace, &oldTbInfoID); err != nil {
-			return 0, errors.Trace(err)
-		}
-		// When the statement is "create or replace view " and we need to drop the old view,
-		// it has two table IDs and should be handled differently.
-		if oldTbInfoID > 0 && orReplace {
-			diff.OldTableID = oldTbInfoID
-		}
-		diff.TableID = tbInfo.ID
-	case model.ActionRenameTable:
-		err = job.DecodeArgs(&diff.OldSchemaID)
-		if err != nil {
-			return 0, errors.Trace(err)
-		}
-		diff.TableID = job.TableID
-	case model.ActionRenameTables:
-		var (
-			oldSchemaIDs, newSchemaIDs, tableIDs []int64
-			tableNames, oldSchemaNames           []*model.CIStr
-		)
-		err = job.DecodeArgs(&oldSchemaIDs, &newSchemaIDs, &tableNames, &tableIDs, &oldSchemaNames)
-		if err != nil {
-			return 0, errors.Trace(err)
-		}
-		affects := make([]*model.AffectedOption, len(newSchemaIDs)-1)
-		for i, newSchemaID := range newSchemaIDs {
-			// Do not add the first table to AffectedOpts. Related issue tidb#47064.
-			if i == 0 {
-				continue
-			}
-			affects[i-1] = &model.AffectedOption{
-				SchemaID:    newSchemaID,
-				TableID:     tableIDs[i],
-				OldTableID:  tableIDs[i],
-				OldSchemaID: oldSchemaIDs[i],
-			}
-		}
-		diff.TableID = tableIDs[0]
-		diff.SchemaID = newSchemaIDs[0]
-		diff.OldSchemaID = oldSchemaIDs[0]
-		diff.AffectedOpts = affects
-	case model.ActionExchangeTablePartition:
-		// From start of function: diff.SchemaID = job.SchemaID
-		// Old is original non partitioned table
-		diff.OldTableID = job.TableID
-		diff.OldSchemaID = job.SchemaID
-		// Update the partitioned table (it is only done in the last state)
-		var (
-			ptSchemaID     int64
-			ptTableID      int64
-			ptDefID        int64
-			partName       string // Not used
-			withValidation bool   // Not used
-		)
-		// See ddl.ExchangeTablePartition
-		err = job.DecodeArgs(&ptDefID, &ptSchemaID, &ptTableID, &partName, &withValidation)
-		if err != nil {
-			return 0, errors.Trace(err)
-		}
-		// This is needed for not crashing TiFlash!
-		// TODO: Update TiFlash, to handle StateWriteOnly
-		diff.AffectedOpts = []*model.AffectedOption{{
-			TableID: ptTableID,
-		}}
-		if job.SchemaState != model.StatePublic {
-			// No change, just to refresh the non-partitioned table
-			// with its new ExchangePartitionInfo.
-			diff.TableID = job.TableID
-			// Keep this as Schema ID of non-partitioned table
-			// to avoid trigger early rename in TiFlash
-			diff.AffectedOpts[0].SchemaID = job.SchemaID
-			// Need reload partition table, use diff.AffectedOpts[0].OldSchemaID to mark it.
-			if len(multiInfos) > 0 {
-				diff.AffectedOpts[0].OldSchemaID = ptSchemaID
-			}
-		} else {
-			// Swap
-			diff.TableID = ptDefID
-			// Also add correct SchemaID in case different schemas
-			diff.AffectedOpts[0].SchemaID = ptSchemaID
-		}
-	case model.ActionTruncateTablePartition:
-		diff.TableID = job.TableID
-		if len(job.CtxVars) > 0 {
-			oldIDs := job.CtxVars[0].([]int64)
-			newIDs := job.CtxVars[1].([]int64)
-			diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs)
-		}
-	case model.ActionDropTablePartition, model.ActionRecoverTable, model.ActionDropTable:
-		// affects are used to update placement rule cache
-		diff.TableID = job.TableID
-		if len(job.CtxVars) > 0 {
-			if oldIDs, ok := job.CtxVars[0].([]int64); ok {
-				diff.AffectedOpts = buildPlacementAffects(oldIDs, oldIDs)
-			}
-		}
-	case model.ActionReorganizePartition:
-		diff.TableID = job.TableID
-		// TODO: should this be for every state of Reorganize?
-		if len(job.CtxVars) > 0 {
-			if droppedIDs, ok := job.CtxVars[0].([]int64); ok {
-				if addedIDs, ok := job.CtxVars[1].([]int64); ok {
-					// to use AffectedOpts we need both new and old to have the same length
-					maxParts := mathutil.Max[int](len(droppedIDs), len(addedIDs))
-					// Also initialize them to 0!
-					oldIDs := make([]int64, maxParts)
-					copy(oldIDs, droppedIDs)
-					newIDs := make([]int64, maxParts)
-					copy(newIDs, addedIDs)
-					diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs)
-				}
-			}
-		}
-	case model.ActionRemovePartitioning, model.ActionAlterTablePartitioning:
-		diff.TableID = job.TableID
-		diff.OldTableID = job.TableID
-		if job.SchemaState == model.StateDeleteReorganization {
-			partInfo := &model.PartitionInfo{}
-			var partNames []string
-			err = job.DecodeArgs(&partNames, &partInfo)
-			if err != nil {
-				return 0, errors.Trace(err)
-			}
-			// Final part, new table id is assigned
-			diff.TableID = partInfo.NewTableID
-			if len(job.CtxVars) > 0 {
-				if droppedIDs, ok := job.CtxVars[0].([]int64); ok {
-					if addedIDs, ok := job.CtxVars[1].([]int64); ok {
-						// to use AffectedOpts we need both new and old to have the same length
-						maxParts := mathutil.Max[int](len(droppedIDs), len(addedIDs))
-						// Also initialize them to 0!
-						oldIDs := make([]int64, maxParts)
-						copy(oldIDs, droppedIDs)
-						newIDs := make([]int64, maxParts)
-						copy(newIDs, addedIDs)
-						diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs)
-					}
-				}
-			}
-		}
-	case model.ActionCreateTable:
-		diff.TableID = job.TableID
-		if len(job.Args) > 0 {
-			tbInfo, _ := job.Args[0].(*model.TableInfo)
-			// When create table with foreign key, there are two schema status change:
-			// 1. none -> write-only
-			// 2. write-only -> public
-			// In the second status change write-only -> public, infoschema loader should apply drop old table first, then
-			// apply create new table. So need to set diff.OldTableID here to make sure it.
-			if tbInfo != nil && tbInfo.State == model.StatePublic && len(tbInfo.ForeignKeys) > 0 {
-				diff.OldTableID = job.TableID
-			}
-		}
-	case model.ActionRecoverSchema:
-		var (
-			recoverSchemaInfo      *RecoverSchemaInfo
-			recoverSchemaCheckFlag int64
-		)
-		err = job.DecodeArgs(&recoverSchemaInfo, &recoverSchemaCheckFlag)
-		if err != nil {
-			return 0, errors.Trace(err)
-		}
-		// Reserved recoverSchemaCheckFlag value for gc work judgment.
-		job.Args[checkFlagIndexInJobArgs] = recoverSchemaCheckFlag
-		recoverTabsInfo := recoverSchemaInfo.RecoverTabsInfo
-		diff.AffectedOpts = make([]*model.AffectedOption, len(recoverTabsInfo))
-		for i := range recoverTabsInfo {
-			diff.AffectedOpts[i] = &model.AffectedOption{
-				SchemaID:    job.SchemaID,
-				OldSchemaID: job.SchemaID,
-				TableID:     recoverTabsInfo[i].TableInfo.ID,
-				OldTableID:  recoverTabsInfo[i].TableInfo.ID,
-			}
-		}
-	case model.ActionFlashbackCluster:
-		diff.TableID = -1
-		if job.SchemaState == model.StatePublic {
-			diff.RegenerateSchemaMap = true
-		}
-	default:
-		diff.TableID = job.TableID
-	}
-	if len(multiInfos) > 0 {
-		existsMap := make(map[int64]struct{})
-		existsMap[diff.TableID] = struct{}{}
-		for _, affect := range diff.AffectedOpts {
-			existsMap[affect.TableID] = struct{}{}
-		}
-		for _, info := range multiInfos {
-			_, exist := existsMap[info.tblInfo.ID]
-			if exist {
-				continue
-			}
-			existsMap[info.tblInfo.ID] = struct{}{}
-			diff.AffectedOpts = append(diff.AffectedOpts, &model.AffectedOption{
-				SchemaID:    info.schemaID,
-				OldSchemaID: info.schemaID,
-				TableID:     info.tblInfo.ID,
-				OldTableID:  info.tblInfo.ID,
-			})
-		}
-	}
-	err = t.SetSchemaDiff(diff)
-	return schemaVersion, errors.Trace(err)
-}
diff --git a/pkg/ddl/schema_version.go b/pkg/ddl/schema_version.go
new file mode 100644
index 0000000000000..760c1d3b3d888
--- /dev/null
+++ b/pkg/ddl/schema_version.go
@@ -0,0 +1,416 @@
+// Copyright 2024 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ddl
+
+import (
+	"time"
+
+	"github.com/pingcap/errors"
+	"github.com/pingcap/failpoint"
+	"github.com/pingcap/tidb/pkg/kv"
+	"github.com/pingcap/tidb/pkg/meta"
+	"github.com/pingcap/tidb/pkg/parser/model"
+	"github.com/pingcap/tidb/pkg/util/logutil"
+	"github.com/pingcap/tidb/pkg/util/mathutil"
+	"go.uber.org/zap"
+)
+
+// SetSchemaDiffForCreateTables sets SchemaDiff for ActionCreateTables.
+func SetSchemaDiffForCreateTables(diff *model.SchemaDiff, job *model.Job) error {
+	var tableInfos []*model.TableInfo
+	err := job.DecodeArgs(&tableInfos)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	diff.AffectedOpts = make([]*model.AffectedOption, len(tableInfos))
+	for i := range tableInfos {
+		diff.AffectedOpts[i] = &model.AffectedOption{
+			SchemaID:    job.SchemaID,
+			OldSchemaID: job.SchemaID,
+			TableID:     tableInfos[i].ID,
+			OldTableID:  tableInfos[i].ID,
+		}
+	}
+	return nil
+}
+
+// SetSchemaDiffForTruncateTable sets SchemaDiff for ActionTruncateTable.
+func SetSchemaDiffForTruncateTable(diff *model.SchemaDiff, job *model.Job) error {
+	// Truncate table has two table IDs and should be handled differently.
+	err := job.DecodeArgs(&diff.TableID)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	diff.OldTableID = job.TableID
+
+	// affects are used to update placement rule cache
+	if len(job.CtxVars) > 0 {
+		oldIDs := job.CtxVars[0].([]int64)
+		newIDs := job.CtxVars[1].([]int64)
+		diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs)
+	}
+	return nil
+}
+
+// SetSchemaDiffForCreateView sets SchemaDiff for ActionCreateView.
+func SetSchemaDiffForCreateView(diff *model.SchemaDiff, job *model.Job) error {
+	tbInfo := &model.TableInfo{}
+	var orReplace bool
+	var oldTbInfoID int64
+	if err := job.DecodeArgs(tbInfo, &orReplace, &oldTbInfoID); err != nil {
+		return errors.Trace(err)
+	}
+	// When the statement is "create or replace view " and we need to drop the old view,
+	// it has two table IDs and should be handled differently.
+	if oldTbInfoID > 0 && orReplace {
+		diff.OldTableID = oldTbInfoID
+	}
+	diff.TableID = tbInfo.ID
+	return nil
+}
+
+// SetSchemaDiffForRenameTable sets SchemaDiff for ActionRenameTable.
+func SetSchemaDiffForRenameTable(diff *model.SchemaDiff, job *model.Job) error {
+	err := job.DecodeArgs(&diff.OldSchemaID)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	diff.TableID = job.TableID
+	return nil
+}
+
+// SetSchemaDiffForRenameTables sets SchemaDiff for ActionRenameTables.
+func SetSchemaDiffForRenameTables(diff *model.SchemaDiff, job *model.Job) error {
+	var (
+		oldSchemaIDs, newSchemaIDs, tableIDs []int64
+		tableNames, oldSchemaNames           []*model.CIStr
+	)
+	err := job.DecodeArgs(&oldSchemaIDs, &newSchemaIDs, &tableNames, &tableIDs, &oldSchemaNames)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	affects := make([]*model.AffectedOption, len(newSchemaIDs)-1)
+	for i, newSchemaID := range newSchemaIDs {
+		// Do not add the first table to AffectedOpts. Related issue tidb#47064.
+		if i == 0 {
+			continue
+		}
+		affects[i-1] = &model.AffectedOption{
+			SchemaID:    newSchemaID,
+			TableID:     tableIDs[i],
+			OldTableID:  tableIDs[i],
+			OldSchemaID: oldSchemaIDs[i],
+		}
+	}
+	diff.TableID = tableIDs[0]
+	diff.SchemaID = newSchemaIDs[0]
+	diff.OldSchemaID = oldSchemaIDs[0]
+	diff.AffectedOpts = affects
+	return nil
+}
+
+// SetSchemaDiffForExchangeTablePartition sets SchemaDiff for ActionExchangeTablePartition.
+func SetSchemaDiffForExchangeTablePartition(diff *model.SchemaDiff, job *model.Job, multiInfos ...schemaIDAndTableInfo) error {
+	// From start of function: diff.SchemaID = job.SchemaID
+	// Old is original non partitioned table
+	diff.OldTableID = job.TableID
+	diff.OldSchemaID = job.SchemaID
+	// Update the partitioned table (it is only done in the last state)
+	var (
+		ptSchemaID     int64
+		ptTableID      int64
+		ptDefID        int64
+		partName       string // Not used
+		withValidation bool   // Not used
+	)
+	// See ddl.ExchangeTablePartition
+	err := job.DecodeArgs(&ptDefID, &ptSchemaID, &ptTableID, &partName, &withValidation)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	// This is needed for not crashing TiFlash!
+	// TODO: Update TiFlash, to handle StateWriteOnly
+	diff.AffectedOpts = []*model.AffectedOption{{
+		TableID: ptTableID,
+	}}
+	if job.SchemaState != model.StatePublic {
+		// No change, just to refresh the non-partitioned table
+		// with its new ExchangePartitionInfo.
+		diff.TableID = job.TableID
+		// Keep this as Schema ID of non-partitioned table
+		// to avoid trigger early rename in TiFlash
+		diff.AffectedOpts[0].SchemaID = job.SchemaID
+		// Need reload partition table, use diff.AffectedOpts[0].OldSchemaID to mark it.
+ if len(multiInfos) > 0 { + diff.AffectedOpts[0].OldSchemaID = ptSchemaID + } + } else { + // Swap + diff.TableID = ptDefID + // Also add correct SchemaID in case different schemas + diff.AffectedOpts[0].SchemaID = ptSchemaID + } + return nil +} + +// SetSchemaDiffForTruncateTablePartition set SchemaDiff for ActionTruncateTablePartition. +func SetSchemaDiffForTruncateTablePartition(diff *model.SchemaDiff, job *model.Job) { + diff.TableID = job.TableID + if len(job.CtxVars) > 0 { + oldIDs := job.CtxVars[0].([]int64) + newIDs := job.CtxVars[1].([]int64) + diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) + } +} + +// SetSchemaDiffForDropTable set SchemaDiff for ActionDropTablePartition, ActionRecoverTable, ActionDropTable. +func SetSchemaDiffForDropTable(diff *model.SchemaDiff, job *model.Job) { + // affects are used to update placement rule cache + diff.TableID = job.TableID + if len(job.CtxVars) > 0 { + if oldIDs, ok := job.CtxVars[0].([]int64); ok { + diff.AffectedOpts = buildPlacementAffects(oldIDs, oldIDs) + } + } +} + +// SetSchemaDiffForReorganizePartition set SchemaDiff for ActionReorganizePartition. +func SetSchemaDiffForReorganizePartition(diff *model.SchemaDiff, job *model.Job) { + diff.TableID = job.TableID + // TODO: should this be for every state of Reorganize? + if len(job.CtxVars) > 0 { + if droppedIDs, ok := job.CtxVars[0].([]int64); ok { + if addedIDs, ok := job.CtxVars[1].([]int64); ok { + // to use AffectedOpts we need both new and old to have the same length + maxParts := mathutil.Max[int](len(droppedIDs), len(addedIDs)) + // Also initialize them to 0! + oldIDs := make([]int64, maxParts) + copy(oldIDs, droppedIDs) + newIDs := make([]int64, maxParts) + copy(newIDs, addedIDs) + diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) + } + } + } +} + +// SetSchemaDiffForPartitionModify set SchemaDiff for ActionRemovePartitioning, ActionAlterTablePartitioning. +func SetSchemaDiffForPartitionModify(diff *model.SchemaDiff, job *model.Job) error { + diff.TableID = job.TableID + diff.OldTableID = job.TableID + if job.SchemaState == model.StateDeleteReorganization { + partInfo := &model.PartitionInfo{} + var partNames []string + err := job.DecodeArgs(&partNames, &partInfo) + if err != nil { + return errors.Trace(err) + } + // Final part, new table id is assigned + diff.TableID = partInfo.NewTableID + if len(job.CtxVars) > 0 { + if droppedIDs, ok := job.CtxVars[0].([]int64); ok { + if addedIDs, ok := job.CtxVars[1].([]int64); ok { + // to use AffectedOpts we need both new and old to have the same length + maxParts := mathutil.Max[int](len(droppedIDs), len(addedIDs)) + // Also initialize them to 0! + oldIDs := make([]int64, maxParts) + copy(oldIDs, droppedIDs) + newIDs := make([]int64, maxParts) + copy(newIDs, addedIDs) + diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) + } + } + } + } + return nil +} + +// SetSchemaDiffForCreateTable set SchemaDiff for ActionCreateTable. +func SetSchemaDiffForCreateTable(diff *model.SchemaDiff, job *model.Job) { + diff.TableID = job.TableID + if len(job.Args) > 0 { + tbInfo, _ := job.Args[0].(*model.TableInfo) + // When create table with foreign key, there are two schema status change: + // 1. none -> write-only + // 2. write-only -> public + // In the second status change write-only -> public, infoschema loader should apply drop old table first, then + // apply create new table. So need to set diff.OldTableID here to make sure it. 
+ if tbInfo != nil && tbInfo.State == model.StatePublic && len(tbInfo.ForeignKeys) > 0 { + diff.OldTableID = job.TableID + } + } +} + +// SetSchemaDiffForRecoverSchema set SchemaDiff for ActionRecoverSchema. +func SetSchemaDiffForRecoverSchema(diff *model.SchemaDiff, job *model.Job) error { + var ( + recoverSchemaInfo *RecoverSchemaInfo + recoverSchemaCheckFlag int64 + ) + err := job.DecodeArgs(&recoverSchemaInfo, &recoverSchemaCheckFlag) + if err != nil { + return errors.Trace(err) + } + // Reserved recoverSchemaCheckFlag value for gc work judgment. + job.Args[checkFlagIndexInJobArgs] = recoverSchemaCheckFlag + recoverTabsInfo := recoverSchemaInfo.RecoverTabsInfo + diff.AffectedOpts = make([]*model.AffectedOption, len(recoverTabsInfo)) + for i := range recoverTabsInfo { + diff.AffectedOpts[i] = &model.AffectedOption{ + SchemaID: job.SchemaID, + OldSchemaID: job.SchemaID, + TableID: recoverTabsInfo[i].TableInfo.ID, + OldTableID: recoverTabsInfo[i].TableInfo.ID, + } + } + return nil +} + +// SetSchemaDiffForFlashbackCluster set SchemaDiff for ActionFlashbackCluster. +func SetSchemaDiffForFlashbackCluster(diff *model.SchemaDiff, job *model.Job) { + diff.TableID = -1 + if job.SchemaState == model.StatePublic { + diff.RegenerateSchemaMap = true + } +} + +// SetSchemaDiffForMultiInfos set SchemaDiff for multiInfos. +func SetSchemaDiffForMultiInfos(diff *model.SchemaDiff, multiInfos ...schemaIDAndTableInfo) { + if len(multiInfos) > 0 { + existsMap := make(map[int64]struct{}) + existsMap[diff.TableID] = struct{}{} + for _, affect := range diff.AffectedOpts { + existsMap[affect.TableID] = struct{}{} + } + for _, info := range multiInfos { + _, exist := existsMap[info.tblInfo.ID] + if exist { + continue + } + existsMap[info.tblInfo.ID] = struct{}{} + diff.AffectedOpts = append(diff.AffectedOpts, &model.AffectedOption{ + SchemaID: info.schemaID, + OldSchemaID: info.schemaID, + TableID: info.tblInfo.ID, + OldTableID: info.tblInfo.ID, + }) + } + } +} + +// updateSchemaVersion increments the schema version by 1 and sets SchemaDiff. +func updateSchemaVersion(d *ddlCtx, t *meta.Meta, job *model.Job, multiInfos ...schemaIDAndTableInfo) (int64, error) { + schemaVersion, err := d.setSchemaVersion(job, d.store) + if err != nil { + return 0, errors.Trace(err) + } + diff := &model.SchemaDiff{ + Version: schemaVersion, + Type: job.Type, + SchemaID: job.SchemaID, + } + switch job.Type { + case model.ActionCreateTables: + err = SetSchemaDiffForCreateTables(diff, job) + case model.ActionTruncateTable: + err = SetSchemaDiffForTruncateTable(diff, job) + case model.ActionCreateView: + err = SetSchemaDiffForCreateView(diff, job) + case model.ActionRenameTable: + err = SetSchemaDiffForRenameTable(diff, job) + case model.ActionRenameTables: + err = SetSchemaDiffForRenameTables(diff, job) + case model.ActionExchangeTablePartition: + err = SetSchemaDiffForExchangeTablePartition(diff, job, multiInfos...) 
+	case model.ActionTruncateTablePartition:
+		SetSchemaDiffForTruncateTablePartition(diff, job)
+	case model.ActionDropTablePartition, model.ActionRecoverTable, model.ActionDropTable:
+		SetSchemaDiffForDropTable(diff, job)
+	case model.ActionReorganizePartition:
+		SetSchemaDiffForReorganizePartition(diff, job)
+	case model.ActionRemovePartitioning, model.ActionAlterTablePartitioning:
+		err = SetSchemaDiffForPartitionModify(diff, job)
+	case model.ActionCreateTable:
+		SetSchemaDiffForCreateTable(diff, job)
+	case model.ActionRecoverSchema:
+		err = SetSchemaDiffForRecoverSchema(diff, job)
+	case model.ActionFlashbackCluster:
+		SetSchemaDiffForFlashbackCluster(diff, job)
+	default:
+		diff.TableID = job.TableID
+	}
+	if err != nil {
+		return 0, err
+	}
+	SetSchemaDiffForMultiInfos(diff, multiInfos...)
+	err = t.SetSchemaDiff(diff)
+	return schemaVersion, errors.Trace(err)
+}
+
+func checkAllVersions(d *ddlCtx, job *model.Job, latestSchemaVersion int64, timeStart time.Time) error {
+	failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) {
+		if val.(bool) {
+			if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion {
+				panic("check down before update global version failed")
+			}
+			mockDDLErrOnce = -1
+		}
+	})
+
+	// OwnerCheckAllVersions returns only when all TiDB schemas are synced (excluding isolated TiDB nodes).
+	err := d.schemaSyncer.OwnerCheckAllVersions(d.ctx, job.ID, latestSchemaVersion)
+	if err != nil {
+		logutil.Logger(d.ctx).Info("wait latest schema version encounter error", zap.String("category", "ddl"), zap.Int64("ver", latestSchemaVersion),
+			zap.Int64("jobID", job.ID), zap.Duration("take time", time.Since(timeStart)), zap.Error(err))
+		return err
+	}
+	logutil.Logger(d.ctx).Info("wait latest schema version changed(get the metadata lock if tidb_enable_metadata_lock is true)", zap.String("category", "ddl"),
+		zap.Int64("ver", latestSchemaVersion),
+		zap.Duration("take time", time.Since(timeStart)),
+		zap.String("job", job.String()))
+	return nil
+}
+
+// waitSchemaSynced handles the following situation:
+// if the job enters a new state and the worker crashes while waiting for the
+// 2 * lease window, the restarted worker may re-run the job immediately, without
+// having waited the full 2 * lease time for other servers to update their schema.
+// So here we fetch the latest schema version and make sure every server in the
+// cluster has caught up to it, or else wait for 2 * lease time.
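+//
+// A rough sketch of the call pattern (hypothetical caller, for illustration
+// only):
+//
+//	// before re-running a job after a possible crash:
+//	if err := waitSchemaSynced(d, job, 2*lease); err != nil {
+//		return err
+//	}
+//	// all live servers have now reported the latest version, or the
+//	// 2 * lease window has elapsed.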
+func waitSchemaSynced(d *ddlCtx, job *model.Job, waitTime time.Duration) error { + if !job.IsRunning() && !job.IsRollingback() && !job.IsDone() && !job.IsRollbackDone() { + return nil + } + + ver, _ := d.store.CurrentVersion(kv.GlobalTxnScope) + snapshot := d.store.GetSnapshot(ver) + m := meta.NewSnapshotMeta(snapshot) + latestSchemaVersion, err := m.GetSchemaVersionWithNonEmptyDiff() + if err != nil { + logutil.Logger(d.ctx).Warn("get global version failed", zap.String("category", "ddl"), zap.Int64("jobID", job.ID), zap.Error(err)) + return err + } + + failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { + if val.(bool) { + if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion { + panic("check down before update global version failed") + } + mockDDLErrOnce = -1 + } + }) + + return waitSchemaChanged(d, waitTime, latestSchemaVersion, job) +} diff --git a/pkg/ddl/syncer/syncer.go b/pkg/ddl/syncer/syncer.go index d47d284b436a1..4ef458d6b4fcf 100644 --- a/pkg/ddl/syncer/syncer.go +++ b/pkg/ddl/syncer/syncer.go @@ -338,8 +338,8 @@ func (s *schemaVersionSyncer) OwnerCheckAllVersions(ctx context.Context, jobID i key := string(kv.Key) tidbIDInResp := key[strings.LastIndex(key, "/")+1:] // We need to check if the tidb ID is in the updatedMap, in case that deleting etcd is failed, and tidb server is down. - isUpdated := updatedMap[tidbIDInResp] != "" - succ = isUpdatedLatestVersion(string(kv.Key), string(kv.Value), latestVer, notMatchVerCnt, intervalCnt, isUpdated) + nodeAlive := updatedMap[tidbIDInResp] != "" + succ = isUpdatedLatestVersion(string(kv.Key), string(kv.Value), latestVer, notMatchVerCnt, intervalCnt, nodeAlive) if !succ { break } @@ -375,14 +375,14 @@ func (s *schemaVersionSyncer) OwnerCheckAllVersions(ctx context.Context, jobID i } } -func isUpdatedLatestVersion(key, val string, latestVer int64, notMatchVerCnt, intervalCnt int, isUpdated bool) bool { +func isUpdatedLatestVersion(key, val string, latestVer int64, notMatchVerCnt, intervalCnt int, nodeAlive bool) bool { ver, err := strconv.Atoi(val) if err != nil { logutil.BgLogger().Info("syncer check all versions, convert value to int failed, continue checking.", zap.String("category", "ddl"), zap.String("ddl", key), zap.String("value", val), zap.Error(err)) return false } - if int64(ver) < latestVer && isUpdated { + if int64(ver) < latestVer && nodeAlive { if notMatchVerCnt%intervalCnt == 0 { logutil.BgLogger().Info("syncer check all versions, someone is not synced, continue checking", zap.String("category", "ddl"), zap.String("ddl", key), zap.Int("currentVer", ver), zap.Int64("latestVer", latestVer)) diff --git a/pkg/domain/domain.go b/pkg/domain/domain.go index 81b7bda7a7948..0eb177f35cdd3 100644 --- a/pkg/domain/domain.go +++ b/pkg/domain/domain.go @@ -797,7 +797,8 @@ func (do *Domain) refreshMDLCheckTableInfo() { defer do.sysSessionPool.Put(se) exec := sctx.GetRestrictedSQLExecutor() domainSchemaVer := do.InfoSchema().SchemaMetaVersion() - rows, _, err := exec.ExecRestrictedSQL(kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta), nil, fmt.Sprintf("select job_id, version, table_ids from mysql.tidb_mdl_info where version <= %d", domainSchemaVer)) + rows, _, err := exec.ExecRestrictedSQL(kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta), nil, + fmt.Sprintf("select job_id, version, table_ids from mysql.tidb_mdl_info where version <= %d", domainSchemaVer)) if err != nil { logutil.BgLogger().Warn("get mdl info from tidb_mdl_info failed", zap.Error(err)) return 
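The `isUpdated` to `nodeAlive` rename above encodes the owner's real question: a lagging schema version should only block the wait while the node reporting it is still alive. A minimal standalone sketch of that predicate (illustrative names, not the PR's exact code):

	package main

	import "fmt"

	// lagsBehind mirrors the check in isUpdatedLatestVersion: a node whose
	// reported version is older than latestVer only blocks the DDL wait while
	// it is still alive (i.e. present in the map rebuilt from etcd).
	func lagsBehind(ver, latestVer int64, nodeAlive bool) bool {
		return ver < latestVer && nodeAlive
	}

	func main() {
		fmt.Println(lagsBehind(7, 9, true))  // true: a live node has not synced yet
		fmt.Println(lagsBehind(7, 9, false)) // false: a dead node no longer blocks
	}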
diff --git a/pkg/executor/analyze.go b/pkg/executor/analyze.go index 14e5f0cb45d91..2f87ce7116ca5 100644 --- a/pkg/executor/analyze.go +++ b/pkg/executor/analyze.go @@ -76,10 +76,6 @@ var ( MaxRegionSampleSize = int64(1000) ) -const ( - maxSketchSize = 10000 -) - type taskType int const ( diff --git a/pkg/executor/analyze_col.go b/pkg/executor/analyze_col.go index 2ba5272041ac7..3d699b5e26af5 100644 --- a/pkg/executor/analyze_col.go +++ b/pkg/executor/analyze_col.go @@ -157,7 +157,7 @@ func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range, needExtStats boo handleHist = &statistics.Histogram{} handleCms = statistics.NewCMSketch(int32(e.opts[ast.AnalyzeOptCMSketchDepth]), int32(e.opts[ast.AnalyzeOptCMSketchWidth])) handleTopn = statistics.NewTopN(int(e.opts[ast.AnalyzeOptNumTopN])) - handleFms = statistics.NewFMSketch(maxSketchSize) + handleFms = statistics.NewFMSketch(statistics.MaxSketchSize) if e.analyzePB.IdxReq.Version != nil { statsVer = int(*e.analyzePB.IdxReq.Version) } @@ -167,7 +167,7 @@ func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range, needExtStats boo for i := range collectors { collectors[i] = &statistics.SampleCollector{ IsMerger: true, - FMSketch: statistics.NewFMSketch(maxSketchSize), + FMSketch: statistics.NewFMSketch(statistics.MaxSketchSize), MaxSampleSize: int64(e.opts[ast.AnalyzeOptNumSamples]), CMSketch: statistics.NewCMSketch(int32(e.opts[ast.AnalyzeOptCMSketchDepth]), int32(e.opts[ast.AnalyzeOptCMSketchWidth])), } diff --git a/pkg/executor/analyze_col_v2.go b/pkg/executor/analyze_col_v2.go index eeba3f89bdd22..07f14a5819d30 100644 --- a/pkg/executor/analyze_col_v2.go +++ b/pkg/executor/analyze_col_v2.go @@ -228,7 +228,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats( l := len(e.analyzePB.ColReq.ColumnsInfo) + len(e.analyzePB.ColReq.ColumnGroups) rootRowCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) for i := 0; i < l; i++ { - rootRowCollector.Base().FMSketches = append(rootRowCollector.Base().FMSketches, statistics.NewFMSketch(maxSketchSize)) + rootRowCollector.Base().FMSketches = append(rootRowCollector.Base().FMSketches, statistics.NewFMSketch(statistics.MaxSketchSize)) } sc := e.ctx.GetSessionVars().StmtCtx @@ -562,7 +562,7 @@ func (e *AnalyzeColumnsExecV2) buildSubIndexJobForSpecialIndex(indexInfos []*mod NumColumns: int32(len(indexInfo.Columns)), TopNSize: &topnSize, Version: statsVersion, - SketchSize: maxSketchSize, + SketchSize: statistics.MaxSketchSize, } if idxExec.isCommonHandle && indexInfo.Primary { idxExec.analyzePB.Tp = tipb.AnalyzeType_TypeCommonHandle @@ -621,7 +621,7 @@ func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResu }) retCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) for i := 0; i < l; i++ { - retCollector.Base().FMSketches = append(retCollector.Base().FMSketches, statistics.NewFMSketch(maxSketchSize)) + retCollector.Base().FMSketches = append(retCollector.Base().FMSketches, statistics.NewFMSketch(statistics.MaxSketchSize)) } for { data, ok := <-taskCh diff --git a/pkg/executor/analyze_idx.go b/pkg/executor/analyze_idx.go index a5b62957dfab6..a7c0a4b86fb71 100644 --- a/pkg/executor/analyze_idx.go +++ b/pkg/executor/analyze_idx.go @@ -191,7 +191,7 @@ func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, nee cms = statistics.NewCMSketch(int32(e.opts[ast.AnalyzeOptCMSketchDepth]), int32(e.opts[ast.AnalyzeOptCMSketchWidth])) 
topn = statistics.NewTopN(int(e.opts[ast.AnalyzeOptNumTopN])) } - fms := statistics.NewFMSketch(maxSketchSize) + fms := statistics.NewFMSketch(statistics.MaxSketchSize) statsVer := statistics.Version1 if e.analyzePB.IdxReq.Version != nil { statsVer = int(*e.analyzePB.IdxReq.Version) diff --git a/pkg/executor/builder.go b/pkg/executor/builder.go index 2ffe47d50995e..210c801585913 100644 --- a/pkg/executor/builder.go +++ b/pkg/executor/builder.go @@ -2507,7 +2507,7 @@ func (b *executorBuilder) buildAnalyzeIndexPushdown(task plannercore.AnalyzeInde NumColumns: int32(len(task.IndexInfo.Columns)), TopNSize: topNSize, Version: statsVersion, - SketchSize: maxSketchSize, + SketchSize: statistics.MaxSketchSize, } if e.isCommonHandle && e.idxInfo.Primary { e.analyzePB.Tp = tipb.AnalyzeType_TypeCommonHandle @@ -2625,7 +2625,7 @@ func (b *executorBuilder) buildAnalyzeSamplingPushdown( BucketSize: int64(opts[ast.AnalyzeOptNumBuckets]), SampleSize: int64(opts[ast.AnalyzeOptNumSamples]), SampleRate: sampleRate, - SketchSize: maxSketchSize, + SketchSize: statistics.MaxSketchSize, ColumnsInfo: util.ColumnsToProto(task.ColsInfo, task.TblInfo.PKIsHandle, false), ColumnGroups: colGroups, } @@ -2755,7 +2755,7 @@ func (b *executorBuilder) buildAnalyzeColumnsPushdown( e.analyzePB.ColReq = &tipb.AnalyzeColumnsReq{ BucketSize: int64(opts[ast.AnalyzeOptNumBuckets]), SampleSize: MaxRegionSampleSize, - SketchSize: maxSketchSize, + SketchSize: statistics.MaxSketchSize, ColumnsInfo: util.ColumnsToProto(cols, task.HandleCols != nil && task.HandleCols.IsInt(), false), CmsketchDepth: &depth, CmsketchWidth: &width, @@ -2781,7 +2781,7 @@ func (b *executorBuilder) buildAnalyzeColumnsPushdown( width := int32(opts[ast.AnalyzeOptCMSketchWidth]) e.analyzePB.IdxReq.CmsketchDepth = &depth e.analyzePB.IdxReq.CmsketchWidth = &width - e.analyzePB.IdxReq.SketchSize = maxSketchSize + e.analyzePB.IdxReq.SketchSize = statistics.MaxSketchSize e.analyzePB.ColReq.PrimaryColumnIds = tables.TryGetCommonPkColumnIds(task.TblInfo) e.analyzePB.Tp = tipb.AnalyzeType_TypeMixed e.commonHandle = task.CommonHandleInfo diff --git a/pkg/executor/importer/table_import.go b/pkg/executor/importer/table_import.go index 3fb097261363c..1e5d1ad2b14e6 100644 --- a/pkg/executor/importer/table_import.go +++ b/pkg/executor/importer/table_import.go @@ -966,6 +966,10 @@ func GetImportRootDir(tidbCfg *tidb.Config) string { } // FlushTableStats flushes the stats of the table. +// stats will be flushed in domain.updateStatsWorker, default interval is [1, 2) minutes, +// see DumpStatsDeltaToKV for more details. then the background analyzer will analyze +// the table. +// the stats stay in memory until the next flush, so it might be lost if the tidb-server restarts. 
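+// A hedged usage sketch (the caller and variable names here are assumptions,
+// not part of this change):
+//
+//	if err := FlushTableStats(ctx, se, tbl.Meta().ID, &result); err != nil {
+//		logutil.BgLogger().Warn("flush table stats failed", zap.Error(err))
+//	}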
func FlushTableStats(ctx context.Context, se sessionctx.Context, tableID int64, result *JobImportResult) error { if err := sessiontxn.NewTxn(ctx, se); err != nil { return err diff --git a/pkg/executor/test/oomtest/oom_test.go b/pkg/executor/test/oomtest/oom_test.go index 203de6af4b084..30cb654841457 100644 --- a/pkg/executor/test/oomtest/oom_test.go +++ b/pkg/executor/test/oomtest/oom_test.go @@ -56,12 +56,9 @@ func TestMemTracker4UpdateExec(t *testing.T) { oom.SetTracker("") oom.ClearMessageFilter() - oom.AddMessageFilter( - "expensive_query during bootstrap phase", - "schemaLeaseChecker is not set for this transaction") + oom.AddMessageFilter("expensive_query during bootstrap phase") tk.MustExec("insert into t_MemTracker4UpdateExec values (1,1,1), (2,2,2), (3,3,3)") - require.Equal(t, "schemaLeaseChecker is not set for this transaction", oom.GetTracker()) tk.Session().GetSessionVars().MemQuotaQuery = 244 tk.MustExec("update t_MemTracker4UpdateExec set a = 4") @@ -81,12 +78,9 @@ func TestMemTracker4InsertAndReplaceExec(t *testing.T) { log.SetLevel(zap.InfoLevel) oom.SetTracker("") - oom.AddMessageFilter( - "schemaLeaseChecker is not set for this transaction", - "expensive_query during bootstrap phase") + oom.AddMessageFilter("expensive_query during bootstrap phase") tk.MustExec("insert into t_MemTracker4InsertAndReplaceExec values (1,1,1), (2,2,2), (3,3,3)") - require.Equal(t, "schemaLeaseChecker is not set for this transaction", oom.GetTracker()) tk.Session().GetSessionVars().MemQuotaQuery = 1 oom.ClearMessageFilter() oom.AddMessageFilter("expensive_query during bootstrap phase") diff --git a/pkg/expression/BUILD.bazel b/pkg/expression/BUILD.bazel index bba4285c6caab..464415390e620 100644 --- a/pkg/expression/BUILD.bazel +++ b/pkg/expression/BUILD.bazel @@ -199,6 +199,7 @@ go_test( "//pkg/config", "//pkg/errctx", "//pkg/errno", + "//pkg/expression/context", "//pkg/kv", "//pkg/parser", "//pkg/parser/ast", diff --git a/pkg/expression/bench_test.go b/pkg/expression/bench_test.go index 1144a68f99a8b..18bf00acca102 100644 --- a/pkg/expression/bench_test.go +++ b/pkg/expression/bench_test.go @@ -1436,11 +1436,11 @@ func genVecBuiltinFuncBenchCase(ctx BuildContext, funcName string, testCase vecE tp := eType2FieldType(testCase.retEvalType) switch testCase.retEvalType { case types.ETInt: - fc = &castAsIntFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + fc = &castAsIntFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, false} case types.ETDecimal: - fc = &castAsDecimalFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + fc = &castAsDecimalFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, false} case types.ETReal: - fc = &castAsRealFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + fc = &castAsRealFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, false} case types.ETDatetime, types.ETTimestamp: fc = &castAsTimeFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} case types.ETDuration: diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index cfdb35bbfc23b..f04e5b096b9a6 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -111,7 +111,8 @@ var ( type castAsIntFunctionClass struct { baseFunctionClass - tp *types.FieldType + tp *types.FieldType + inUnion bool } func (c *castAsIntFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { @@ -122,7 +123,7 @@ func (c *castAsIntFunctionClass) getFunction(ctx BuildContext, args []Expression if err != nil { return nil, err } - bf := 
newBaseBuiltinCastFunc(b, ctx.IsInUnionCast()) + bf := newBaseBuiltinCastFunc(b, c.inUnion) if args[0].GetType().Hybrid() || IsBinaryLiteral(args[0]) { sig = &builtinCastIntAsIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastIntAsInt) @@ -160,7 +161,8 @@ func (c *castAsIntFunctionClass) getFunction(ctx BuildContext, args []Expression type castAsRealFunctionClass struct { baseFunctionClass - tp *types.FieldType + tp *types.FieldType + inUnion bool } func (c *castAsRealFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { @@ -171,7 +173,7 @@ func (c *castAsRealFunctionClass) getFunction(ctx BuildContext, args []Expressio if err != nil { return nil, err } - bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast()) + bf := newBaseBuiltinCastFunc(b, c.inUnion) if IsBinaryLiteral(args[0]) { sig = &builtinCastRealAsRealSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastRealAsReal) @@ -215,7 +217,8 @@ func (c *castAsRealFunctionClass) getFunction(ctx BuildContext, args []Expressio type castAsDecimalFunctionClass struct { baseFunctionClass - tp *types.FieldType + tp *types.FieldType + inUnion bool } func (c *castAsDecimalFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { @@ -226,7 +229,7 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx BuildContext, args []Expres if err != nil { return nil, err } - bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast()) + bf := newBaseBuiltinCastFunc(b, c.inUnion) if IsBinaryLiteral(args[0]) { sig = &builtinCastDecimalAsDecimalSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastDecimalAsDecimal) @@ -2052,11 +2055,9 @@ func CanImplicitEvalReal(expr Expression) bool { // BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union // Expression. func BuildCastFunction4Union(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression) { - if !ctx.IsInUnionCast() { - ctx.SetInUnionCast(true) - defer ctx.SetInUnionCast(false) - } - return BuildCastFunction(ctx, expr, tp) + res, err := BuildCastFunctionWithCheck(ctx, expr, tp, true) + terror.Log(err) + return } // BuildCastCollationFunction builds a ScalarFunction which casts the collation. @@ -2091,13 +2092,13 @@ func BuildCastCollationFunction(ctx BuildContext, expr Expression, ec *ExprColla // BuildCastFunction builds a CAST ScalarFunction from the Expression. func BuildCastFunction(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression) { - res, err := BuildCastFunctionWithCheck(ctx, expr, tp) + res, err := BuildCastFunctionWithCheck(ctx, expr, tp, false) terror.Log(err) return } // BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and return error if any. 
-func BuildCastFunctionWithCheck(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression, err error) { +func BuildCastFunctionWithCheck(ctx BuildContext, expr Expression, tp *types.FieldType, inUnion bool) (res Expression, err error) { argType := expr.GetType() // If source argument's nullable, then target type should be nullable if !mysql.HasNotNullFlag(argType.GetFlag()) { @@ -2107,11 +2108,11 @@ func BuildCastFunctionWithCheck(ctx BuildContext, expr Expression, tp *types.Fie var fc functionClass switch tp.EvalType() { case types.ETInt: - fc = &castAsIntFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + fc = &castAsIntFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, inUnion} case types.ETDecimal: - fc = &castAsDecimalFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + fc = &castAsDecimalFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, inUnion} case types.ETReal: - fc = &castAsRealFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} + fc = &castAsRealFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, inUnion} case types.ETDatetime, types.ETTimestamp: fc = &castAsTimeFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp} case types.ETDuration: diff --git a/pkg/expression/builtin_cast_test.go b/pkg/expression/builtin_cast_test.go index 606151c46f737..97746afa52b07 100644 --- a/pkg/expression/builtin_cast_test.go +++ b/pkg/expression/builtin_cast_test.go @@ -1690,7 +1690,7 @@ func TestCastArrayFunc(t *testing.T) { }, } for _, tt := range tbl { - f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp) + f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp, false) if !tt.buildFuncSuccess { require.Error(t, err, tt.input) continue diff --git a/pkg/expression/constant_test.go b/pkg/expression/constant_test.go index be622ed7caf9e..ffe4be536088b 100644 --- a/pkg/expression/constant_test.go +++ b/pkg/expression/constant_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/types" @@ -194,8 +195,9 @@ func TestConstantPropagation(t *testing.T) { func TestConstantFolding(t *testing.T) { tests := []struct { - condition func(ctx BuildContext) Expression - result string + condition func(ctx BuildContext) Expression + result string + nullRejectCheck bool }{ { condition: func(ctx BuildContext) Expression { @@ -236,15 +238,20 @@ func TestConstantFolding(t *testing.T) { { condition: func(ctx BuildContext) Expression { expr := newFunction(ctx, ast.ConcatWS, newColumn(0), NewNull()) - ctx.SetInNullRejectCheck(true) return expr }, - result: "concat_ws(cast(Column#0, var_string(20)), )", + nullRejectCheck: true, + result: "concat_ws(cast(Column#0, var_string(20)), )", }, } for _, tt := range tests { - ctx := mock.NewContext() + ctx := mock.NewContext().GetExprCtx() + require.False(t, ctx.IsInNullRejectCheck()) expr := tt.condition(ctx) + if tt.nullRejectCheck { + ctx = exprctx.WithNullRejectCheck(ctx) + require.True(t, ctx.IsInNullRejectCheck()) + } newConds := FoldConstant(ctx, expr) require.Equalf(t, tt.result, newConds.String(), "different for expr %s", tt.condition) } diff --git a/pkg/expression/context/context.go b/pkg/expression/context/context.go index 75bcbca413f21..218d37cc76ef9 100644 --- a/pkg/expression/context/context.go +++ b/pkg/expression/context/context.go @@ 
-83,14 +83,11 @@ type BuildContext interface { SetSkipPlanCache(reason string) // AllocPlanColumnID allocates column id for plan. AllocPlanColumnID() int64 - // SetInNullRejectCheck sets the flag to indicate whether the expression is in null reject check. - SetInNullRejectCheck(in bool) // IsInNullRejectCheck returns the flag to indicate whether the expression is in null reject check. + // It should always return `false` in most implementations because we do not want to do null reject check + // in most cases except for the method `isNullRejected` in planner. + // See the comments for `isNullRejected` in planner for more details. IsInNullRejectCheck() bool - // SetInUnionCast sets the flag to indicate whether the expression is in union cast. - SetInUnionCast(in bool) - // IsInUnionCast indicates whether executing in special cast context that negative unsigned num will be zero. - IsInUnionCast() bool // ConnectionID indicates the connection ID of the current session. // If the context is not in a session, it should return 0. ConnectionID() uint64 @@ -107,6 +104,21 @@ type ExprContext interface { GetGroupConcatMaxLen() uint64 } +// NullRejectCheckExprContext is a wrapper to return true for `IsInNullRejectCheck`. +type NullRejectCheckExprContext struct { + ExprContext +} + +// WithNullRejectCheck returns a new `NullRejectCheckExprContext` with the given `ExprContext`. +func WithNullRejectCheck(ctx ExprContext) *NullRejectCheckExprContext { + return &NullRejectCheckExprContext{ExprContext: ctx} +} + +// IsInNullRejectCheck always returns true for `NullRejectCheckExprContext` +func (ctx *NullRejectCheckExprContext) IsInNullRejectCheck() bool { + return true +} + // AssertLocationWithSessionVars asserts the location in the context and session variables are the same. // It is only used for testing. func AssertLocationWithSessionVars(ctxLoc *time.Location, vars *variable.SessionVars) { diff --git a/pkg/expression/contextsession/sessionctx.go b/pkg/expression/contextsession/sessionctx.go index c259f3301c122..e60c96a0cc2cc 100644 --- a/pkg/expression/contextsession/sessionctx.go +++ b/pkg/expression/contextsession/sessionctx.go @@ -17,7 +17,6 @@ package contextsession import ( "context" "math" - "sync/atomic" "time" "github.com/pingcap/tidb/pkg/errctx" @@ -52,8 +51,6 @@ var _ exprctx.ExprContext = struct { type ExprCtxExtendedImpl struct { sctx sessionctx.Context *SessionEvalContext - inNullRejectCheck atomic.Bool - inUnionCast atomic.Bool } // NewExprExtendedImpl creates a new ExprCtxExtendedImpl. @@ -117,24 +114,9 @@ func (ctx *ExprCtxExtendedImpl) AllocPlanColumnID() int64 { return ctx.sctx.GetSessionVars().AllocPlanColumnID() } -// SetInNullRejectCheck sets whether the expression is in null reject check. -func (ctx *ExprCtxExtendedImpl) SetInNullRejectCheck(in bool) { - ctx.inNullRejectCheck.Store(in) -} - // IsInNullRejectCheck returns whether the expression is in null reject check. func (ctx *ExprCtxExtendedImpl) IsInNullRejectCheck() bool { - return ctx.inNullRejectCheck.Load() -} - -// SetInUnionCast sets the flag to indicate whether the expression is in union cast. -func (ctx *ExprCtxExtendedImpl) SetInUnionCast(in bool) { - ctx.inUnionCast.Store(in) -} - -// IsInUnionCast indicates whether executing in special cast context that negative unsigned num will be zero. -func (ctx *ExprCtxExtendedImpl) IsInUnionCast() bool { - return ctx.inUnionCast.Load() + return false } // GetWindowingUseHighPrecision determines whether to compute window operations without loss of precision. 
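With both setters gone, the two flags now travel by different routes: union-cast behaviour is passed explicitly as the new `inUnion` field threaded through `BuildCastFunctionWithCheck`, and null-reject checking is expressed by wrapping the context. A minimal standalone sketch of the wrapper pattern (trimmed, illustrative names; the real interface lives in pkg/expression/context):

	package main

	import "fmt"

	// ExprContext is a trimmed stand-in for the real interface in
	// pkg/expression/context.
	type ExprContext interface {
		IsInNullRejectCheck() bool
	}

	type baseCtx struct{}

	func (baseCtx) IsInNullRejectCheck() bool { return false }

	// nullRejectCtx mirrors NullRejectCheckExprContext: embedding delegates
	// every other method to the wrapped context; only this one answer changes.
	type nullRejectCtx struct{ ExprContext }

	func (nullRejectCtx) IsInNullRejectCheck() bool { return true }

	func main() {
		var ctx ExprContext = baseCtx{}
		fmt.Println(ctx.IsInNullRejectCheck()) // false
		ctx = nullRejectCtx{ExprContext: ctx}  // wrap instead of mutating a flag
		fmt.Println(ctx.IsInNullRejectCheck()) // true
	}

Because the wrapper is a new value rather than a mutated field, the flag can no longer leak between concurrent users of the same session context, which is what the removed atomic.Bool fields had to guard against.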
diff --git a/pkg/expression/contextsession/sessionctx_test.go b/pkg/expression/contextsession/sessionctx_test.go index 993cc67a112cc..323d327d56257 100644 --- a/pkg/expression/contextsession/sessionctx_test.go +++ b/pkg/expression/contextsession/sessionctx_test.go @@ -315,17 +315,6 @@ func TestSessionBuildContext(t *testing.T) {
 
 	// InNullRejectCheck
 	require.False(t, impl.IsInNullRejectCheck())
-	impl.SetInNullRejectCheck(true)
-	require.True(t, impl.IsInNullRejectCheck())
-	impl.SetInNullRejectCheck(false)
-	require.False(t, impl.IsInNullRejectCheck())
-
-	// InUnionCast
-	require.False(t, impl.IsInUnionCast())
-	impl.SetInUnionCast(true)
-	require.True(t, impl.IsInUnionCast())
-	impl.SetInUnionCast(false)
-	require.False(t, impl.IsInUnionCast())
 
 	// ConnID
 	vars.ConnectionID = 123
diff --git a/pkg/infoschema/cache.go b/pkg/infoschema/cache.go index f3d4d67e576c6..32740250b1b0d 100644 --- a/pkg/infoschema/cache.go +++ b/pkg/infoschema/cache.go @@ -143,7 +143,10 @@ func (h *InfoCache) getByVersionNoLock(version int64) InfoSchema {
 		return h.cache[i].infoschema.SchemaMetaVersion() <= version
 	})
 
-	// `GetByVersion` is allowed to load the latest schema that is less than argument `version`.
+	// `GetByVersion` is allowed to load the latest schema that is less than argument
+	// `version` when the argument `version` <= the latest schema version.
+	// If `version` > the latest schema version, this always returns nil; loadInfoSchema
+	// relies on this behavior to decide between loading schema diffs and a full reload.
 	// Consider cache has values [10, 9, _, _, 6, 5, 4, 3, 2, 1], version 8 and 7 is empty because of the diff is empty.
 	// If we want to get version 8, we can return version 6 because v7 and v8 do not change anything, they are totally the same,
 	// in this case the `i` will not be 0.
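The expanded comment pins down the lookup contract. A small standalone sketch of that contract (assumed descending version list, illustrative only, not the cache's real code):

	package main

	import "fmt"

	// pick mimics the documented behaviour of getByVersionNoLock: requests
	// newer than the newest cached version get nothing (forcing a diff load
	// or full reload); otherwise the newest entry not exceeding the request
	// is returned.
	func pick(versions []int64, want int64) (int64, bool) {
		if len(versions) == 0 || want > versions[0] {
			return 0, false
		}
		for _, v := range versions {
			if v <= want {
				return v, true
			}
		}
		return 0, false
	}

	func main() {
		cache := []int64{10, 9, 6, 5} // v7 and v8 had empty diffs
		fmt.Println(pick(cache, 8))  // 6 true: v7/v8 changed nothing
		fmt.Println(pick(cache, 11)) // 0 false: newer than anything cached
	}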
diff --git a/pkg/lightning/backend/local/BUILD.bazel b/pkg/lightning/backend/local/BUILD.bazel index 13eb66eab9019..c297e333d2d7d 100644 --- a/pkg/lightning/backend/local/BUILD.bazel +++ b/pkg/lightning/backend/local/BUILD.bazel @@ -23,6 +23,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//br/pkg/checksum", + "//br/pkg/errors", "//br/pkg/logutil", "//br/pkg/membuf", "//br/pkg/pdutil", diff --git a/pkg/lightning/backend/local/duplicate.go b/pkg/lightning/backend/local/duplicate.go index 7f72b7b998448..2e93a25abe513 100644 --- a/pkg/lightning/backend/local/duplicate.go +++ b/pkg/lightning/backend/local/duplicate.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" + berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/pkg/distsql" @@ -312,7 +313,8 @@ func getDupDetectClient( ) (import_sstpb.ImportSST_DuplicateDetectClient, error) { leader := region.Leader if leader == nil { - leader = region.Region.GetPeers()[0] + return nil, errors.Annotatef(berrors.ErrPDLeaderNotFound, + "region id %d has no leader", region.Region.Id) } importClient, err := importClientFactory.Create(ctx, leader.GetStoreId()) if err != nil { diff --git a/pkg/lightning/backend/local/region_job.go b/pkg/lightning/backend/local/region_job.go index db9170e6af6d2..7bc812e4b9bb6 100644 --- a/pkg/lightning/backend/local/region_job.go +++ b/pkg/lightning/backend/local/region_job.go @@ -30,6 +30,7 @@ import ( sst "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" + berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/pkg/kv" @@ -624,7 +625,8 @@ func (local *Backend) doIngest(ctx context.Context, j *regionJob) (*sst.IngestRe leader := j.region.Leader if leader == nil { - leader = j.region.Region.GetPeers()[0] + return nil, errors.Annotatef(berrors.ErrPDLeaderNotFound, + "region id %d has no leader", j.region.Region.Id) } cli, err := clientFactory.Create(ctx, leader.StoreId) diff --git a/pkg/lightning/mydump/BUILD.bazel b/pkg/lightning/mydump/BUILD.bazel index 754f92ae65b54..381da027f3f29 100644 --- a/pkg/lightning/mydump/BUILD.bazel +++ b/pkg/lightning/mydump/BUILD.bazel @@ -13,6 +13,7 @@ go_library( "reader.go", "region.go", "router.go", + "schema_import.go", ], importpath = "github.com/pingcap/tidb/pkg/lightning/mydump", visibility = ["//visibility:public"], @@ -24,11 +25,18 @@ go_library( "//pkg/lightning/log", "//pkg/lightning/metric", "//pkg/lightning/worker", + "//pkg/parser", + "//pkg/parser/ast", + "//pkg/parser/format", + "//pkg/parser/model", "//pkg/parser/mysql", "//pkg/types", + "//pkg/util", "//pkg/util/filter", "//pkg/util/regexpr-router", + "//pkg/util/set", "//pkg/util/slice", + "//pkg/util/sqlescape", "//pkg/util/table-filter", "//pkg/util/zeropool", "@com_github_pingcap_errors//:errors", @@ -59,6 +67,7 @@ go_test( "reader_test.go", "region_test.go", "router_test.go", + "schema_import_test.go", ], data = glob([ "csv/*", @@ -75,12 +84,14 @@ go_test( "//pkg/lightning/config", "//pkg/lightning/log", "//pkg/lightning/worker", + "//pkg/parser", "//pkg/parser/mysql", "//pkg/testkit/testsetup", "//pkg/types", "//pkg/util/filter", "//pkg/util/table-filter", "//pkg/util/table-router", + 
"@com_github_data_dog_go_sqlmock//:go-sqlmock", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/pkg/lightning/mydump/loader.go b/pkg/lightning/mydump/loader.go index 30c1e316d9cf8..d2fc407ddcdeb 100644 --- a/pkg/lightning/mydump/loader.go +++ b/pkg/lightning/mydump/loader.go @@ -130,7 +130,7 @@ func (m *MDTableMeta) GetSchema(ctx context.Context, store storage.ExternalStora zap.String("Path", m.SchemaFile.FileMeta.Path), log.ShortError(err), ) - return "", err + return "", errors.Trace(err) } return string(schema), nil } diff --git a/pkg/lightning/mydump/schema_import.go b/pkg/lightning/mydump/schema_import.go new file mode 100644 index 0000000000000..32120f0ce457f --- /dev/null +++ b/pkg/lightning/mydump/schema_import.go @@ -0,0 +1,371 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mydump + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/set" + "github.com/pingcap/tidb/pkg/util/sqlescape" + "go.uber.org/zap" +) + +type schemaStmtType int + +// String implements fmt.Stringer interface. +func (stmtType schemaStmtType) String() string { + switch stmtType { + case schemaCreateDatabase: + return "restore database schema" + case schemaCreateTable: + return "restore table schema" + case schemaCreateView: + return "restore view schema" + } + return "unknown statement of schema" +} + +const ( + schemaCreateDatabase schemaStmtType = iota + schemaCreateTable + schemaCreateView +) + +type schemaJob struct { + dbName string + tblName string // empty for create db jobs + stmtType schemaStmtType + sqlStr string +} + +// SchemaImporter is used to import schema from dump files. +type SchemaImporter struct { + logger log.Logger + db *sql.DB + sqlMode mysql.SQLMode + store storage.ExternalStorage + concurrency int +} + +// NewSchemaImporter creates a new SchemaImporter instance. +func NewSchemaImporter(logger log.Logger, sqlMode mysql.SQLMode, db *sql.DB, store storage.ExternalStorage, concurrency int) *SchemaImporter { + return &SchemaImporter{ + logger: logger, + db: db, + sqlMode: sqlMode, + store: store, + concurrency: concurrency, + } +} + +// Run imports all schemas from the given database metas. 
+func (si *SchemaImporter) Run(ctx context.Context, dbMetas []*MDDatabaseMeta) (err error) {
+	logTask := si.logger.Begin(zap.InfoLevel, "restore all schema")
+	defer func() {
+		logTask.End(zap.ErrorLevel, err)
+	}()
+
+	if len(dbMetas) == 0 {
+		return nil
+	}
+
+	if err = si.importDatabases(ctx, dbMetas); err != nil {
+		return errors.Trace(err)
+	}
+	if err = si.importTables(ctx, dbMetas); err != nil {
+		return errors.Trace(err)
+	}
+	return errors.Trace(si.importViews(ctx, dbMetas))
+}
+
+func (si *SchemaImporter) importDatabases(ctx context.Context, dbMetas []*MDDatabaseMeta) error {
+	existingSchemas, err := si.getExistingDatabases(ctx)
+	if err != nil {
+		return err
+	}
+
+	ch := make(chan *MDDatabaseMeta)
+	eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx)
+	for i := 0; i < si.concurrency; i++ {
+		eg.Go(func() error {
+			p := parser.New()
+			p.SetSQLMode(si.sqlMode)
+			for dbMeta := range ch {
+				sqlStr := dbMeta.GetSchema(egCtx, si.store)
+				if err2 := si.runJob(egCtx, p, &schemaJob{
+					dbName:   dbMeta.Name,
+					stmtType: schemaCreateDatabase,
+					sqlStr:   sqlStr,
+				}); err2 != nil {
+					return err2
+				}
+			}
+			return nil
+		})
+	}
+	eg.Go(func() error {
+		defer close(ch)
+		for i := range dbMetas {
+			dbMeta := dbMetas[i]
+			// if the downstream already has this database, we can skip the DDL job
+			if existingSchemas.Exist(strings.ToLower(dbMeta.Name)) {
+				si.logger.Info("database already exists in downstream, skip",
+					zap.String("db", dbMeta.Name),
+				)
+				continue
+			}
+			select {
+			case ch <- dbMeta:
+			case <-egCtx.Done():
+			}
+		}
+		return nil
+	})
+
+	return eg.Wait()
+}
+
+func (si *SchemaImporter) importTables(ctx context.Context, dbMetas []*MDDatabaseMeta) error {
+	ch := make(chan *MDTableMeta)
+	eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx)
+	for i := 0; i < si.concurrency; i++ {
+		eg.Go(func() error {
+			p := parser.New()
+			p.SetSQLMode(si.sqlMode)
+			for tableMeta := range ch {
+				sqlStr, err := tableMeta.GetSchema(egCtx, si.store)
+				if err != nil {
+					return err
+				}
+				if err = si.runJob(egCtx, p, &schemaJob{
+					dbName:   tableMeta.DB,
+					tblName:  tableMeta.Name,
+					stmtType: schemaCreateTable,
+					sqlStr:   sqlStr,
+				}); err != nil {
+					return err
+				}
+			}
+			return nil
+		})
+	}
+	eg.Go(func() error {
+		defer close(ch)
+		for _, dbMeta := range dbMetas {
+			if len(dbMeta.Tables) == 0 {
+				continue
+			}
+			tables, err := si.getExistingTables(egCtx, dbMeta.Name)
+			if err != nil {
+				return err
+			}
+			for i := range dbMeta.Tables {
+				tblMeta := dbMeta.Tables[i]
+				if tables.Exist(strings.ToLower(tblMeta.Name)) {
+					// TiDB already has this table; skip the DDL job and let
+					// the SchemaValid check handle it.
+					si.logger.Info("table already exists in downstream, skip",
+						zap.String("db", dbMeta.Name),
+						zap.String("table", tblMeta.Name),
+					)
+					continue
+				} else if tblMeta.SchemaFile.FileMeta.Path == "" {
+					return common.ErrSchemaNotExists.GenWithStackByArgs(dbMeta.Name, tblMeta.Name)
+				}
+
+				select {
+				case ch <- tblMeta:
+				case <-egCtx.Done():
+					return egCtx.Err()
+				}
+			}
+		}
+		return nil
+	})
+
+	return eg.Wait()
+}
+
+// dumpling dumps a view as a table-schema SQL file that creates a table of the same
+// name as the view, plus a view-schema SQL file that drops that table and creates the view.
+func (si *SchemaImporter) importViews(ctx context.Context, dbMetas []*MDDatabaseMeta) error {
+	// Restore views after all table schemas: views can reference tables across
+	// databases, so every referenced table must already exist.
+	// Views are restored serially rather than concurrently, because concurrent
+	// restoration may raise errors.
+	p := parser.New()
+	p.SetSQLMode(si.sqlMode)
+	for _, dbMeta := range dbMetas {
+		if len(dbMeta.Views) == 0 {
+			continue
+		}
+		existingViews, err := si.getExistingViews(ctx, dbMeta.Name)
+		if err != nil {
+			return err
+		}
+		for _, viewMeta := range dbMeta.Views {
+			if existingViews.Exist(strings.ToLower(viewMeta.Name)) {
+				si.logger.Info("view already exists in downstream, skip",
+					zap.String("db", dbMeta.Name),
+					zap.String("view-name", viewMeta.Name))
+				continue
+			}
+			sqlStr, err := viewMeta.GetSchema(ctx, si.store)
+			if err != nil {
+				return err
+			}
+			if strings.TrimSpace(sqlStr) == "" {
+				si.logger.Info("view schema is empty, skip",
+					zap.String("db", dbMeta.Name),
+					zap.String("view-name", viewMeta.Name))
+				continue
+			}
+			if err = si.runJob(ctx, p, &schemaJob{
+				dbName:   dbMeta.Name,
+				tblName:  viewMeta.Name,
+				stmtType: schemaCreateView,
+				sqlStr:   sqlStr,
+			}); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func (si *SchemaImporter) runJob(ctx context.Context, p *parser.Parser, job *schemaJob) error {
+	stmts, err := createIfNotExistsStmt(p, job.sqlStr, job.dbName, job.tblName)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	conn, err := si.db.Conn(ctx)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		_ = conn.Close()
+	}()
+
+	logger := si.logger.With(zap.String("db", job.dbName), zap.String("table", job.tblName))
+	sqlWithRetry := common.SQLWithRetry{
+		Logger: logger,
+		DB:     conn,
+	}
+	for _, stmt := range stmts {
+		task := logger.Begin(zap.DebugLevel, fmt.Sprintf("execute SQL: %s", stmt))
+		err = sqlWithRetry.Exec(ctx, "run create schema job", stmt)
+		task.End(zap.ErrorLevel, err)
+
+		if err != nil {
+			return common.ErrCreateSchema.Wrap(err).GenWithStackByArgs(common.UniqueTable(job.dbName, job.tblName), job.stmtType.String())
+		}
+	}
+	return nil
+}
+
+func (si *SchemaImporter) getExistingDatabases(ctx context.Context) (set.StringSet, error) {
+	return si.getExistingSchemas(ctx, `SELECT SCHEMA_NAME FROM information_schema.SCHEMATA`)
+}
+
+// the result contains views too, but since tables and views share the same name space, that's OK.
+func (si *SchemaImporter) getExistingTables(ctx context.Context, dbName string) (set.StringSet, error) {
+	sb := new(strings.Builder)
+	sqlescape.MustFormatSQL(sb, `SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %?`, dbName)
+	return si.getExistingSchemas(ctx, sb.String())
+}
+
+func (si *SchemaImporter) getExistingViews(ctx context.Context, dbName string) (set.StringSet, error) {
+	sb := new(strings.Builder)
+	sqlescape.MustFormatSQL(sb, `SELECT TABLE_NAME FROM information_schema.VIEWS WHERE TABLE_SCHEMA = %?`, dbName)
+	return si.getExistingSchemas(ctx, sb.String())
+}
+
+// get existing databases/tables/views using the given query; the first column of
+// the query result should be the name.
+// The returned names are converted to lower case.
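+// For example (hypothetical values), querying TABLES for schema `db1` might
+// yield {"t1", "part_tbl"}; callers must therefore lower-case their own input
+// before membership checks, as the Exist calls above do with strings.ToLower.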
+func (si *SchemaImporter) getExistingSchemas(ctx context.Context, query string) (set.StringSet, error) { + conn, err := si.db.Conn(ctx) + if err != nil { + return nil, errors.Trace(err) + } + defer func() { + _ = conn.Close() + }() + sqlWithRetry := common.SQLWithRetry{ + Logger: si.logger, + DB: conn, + } + stringRows, err := sqlWithRetry.QueryStringRows(ctx, "get existing schemas", query) + if err != nil { + return nil, errors.Trace(err) + } + res := make(set.StringSet, len(stringRows)) + for _, row := range stringRows { + res.Insert(strings.ToLower(row[0])) + } + return res, nil +} + +func createIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName string) ([]string, error) { + stmts, _, err := p.ParseSQL(createTable) + if err != nil { + return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable) + } + + var res strings.Builder + ctx := format.NewRestoreCtx(format.DefaultRestoreFlags|format.RestoreTiDBSpecialComment|format.RestoreWithTTLEnableOff, &res) + + retStmts := make([]string, 0, len(stmts)) + for _, stmt := range stmts { + switch node := stmt.(type) { + case *ast.CreateDatabaseStmt: + node.Name = model.NewCIStr(dbName) + node.IfNotExists = true + case *ast.DropDatabaseStmt: + node.Name = model.NewCIStr(dbName) + node.IfExists = true + case *ast.CreateTableStmt: + node.Table.Schema = model.NewCIStr(dbName) + node.Table.Name = model.NewCIStr(tblName) + node.IfNotExists = true + case *ast.CreateViewStmt: + node.ViewName.Schema = model.NewCIStr(dbName) + node.ViewName.Name = model.NewCIStr(tblName) + case *ast.DropTableStmt: + node.Tables[0].Schema = model.NewCIStr(dbName) + node.Tables[0].Name = model.NewCIStr(tblName) + node.IfExists = true + } + if err := stmt.Restore(ctx); err != nil { + return []string{}, common.ErrInvalidSchemaStmt.Wrap(err).GenWithStackByArgs(createTable) + } + ctx.WritePlain(";") + retStmts = append(retStmts, res.String()) + res.Reset() + } + + return retStmts, nil +} diff --git a/pkg/lightning/mydump/schema_import_test.go b/pkg/lightning/mydump/schema_import_test.go new file mode 100644 index 0000000000000..024aea9484bd7 --- /dev/null +++ b/pkg/lightning/mydump/schema_import_test.go @@ -0,0 +1,370 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mydump + +import ( + "context" + "fmt" + "os" + "path" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestSchemaImporter(t *testing.T) { + db, mock, err := sqlmock.New() + mock.MatchExpectationsInOrder(false) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, mock.ExpectationsWereMet()) + // have to ignore the error here, as sqlmock doesn't allow set number of + // expectations, and each opened connection requires a Close() call. + _ = db.Close() + }) + ctx := context.Background() + tempDir := t.TempDir() + store, err := storage.NewLocalStorage(tempDir) + require.NoError(t, err) + logger := log.Logger{Logger: zap.NewExample()} + importer := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 4) + require.NoError(t, importer.Run(ctx, nil)) + + t.Run("get existing schema err", func(t *testing.T) { + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnError(errors.New("non retryable error")) + require.ErrorContains(t, importer.Run(ctx, []*MDDatabaseMeta{{Name: "test"}}), "non retryable error") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("database already exists", func(t *testing.T) { + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test")) + require.NoError(t, importer.Run(ctx, []*MDDatabaseMeta{{Name: "test"}})) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("create non exist database", func(t *testing.T) { + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"})) + dbMetas := make([]*MDDatabaseMeta, 0, 10) + for i := 0; i < 10; i++ { + mock.ExpectExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `test%02d`", i)). + WillReturnResult(sqlmock.NewResult(0, 0)) + dbMetas = append(dbMetas, &MDDatabaseMeta{Name: fmt.Sprintf("test%02d", i)}) + } + require.NoError(t, importer.Run(ctx, dbMetas)) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("break on database error", func(t *testing.T) { + importer2 := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 1) + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"})) + fileName := "invalid-schema.sql" + require.NoError(t, os.WriteFile(path.Join(tempDir, fileName), []byte("CREATE invalid;"), 0o644)) + dbMetas := []*MDDatabaseMeta{ + {Name: "test", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileName}}}, + {Name: "test2"}, // not chance to run + } + require.ErrorContains(t, importer2.Run(ctx, dbMetas), "invalid schema statement") + require.NoError(t, mock.ExpectationsWereMet()) + require.NoError(t, os.Remove(path.Join(tempDir, fileName))) + + dbMetas = append([]*MDDatabaseMeta{{Name: "ttt"}}, dbMetas...) + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"})) + mock.ExpectExec("CREATE DATABASE IF NOT EXISTS `ttt`"). 
+ WillReturnError(errors.New("non retryable error")) + err2 := importer2.Run(ctx, dbMetas) + require.ErrorIs(t, err2, common.ErrCreateSchema) + require.ErrorContains(t, err2, "non retryable error") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("table: get existing schema err", func(t *testing.T) { + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"}). + AddRow("test01").AddRow("test02").AddRow("test03"). + AddRow("test04").AddRow("test05")) + mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test02'"). + WillReturnError(errors.New("non retryable error")) + dbMetas := []*MDDatabaseMeta{ + {Name: "test01"}, + {Name: "test02", Tables: []*MDTableMeta{{DB: "test02", Name: "t"}}}, + } + require.ErrorContains(t, importer.Run(ctx, dbMetas), "non retryable error") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("table: invalid schema file", func(t *testing.T) { + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"}). + AddRow("test01").AddRow("test02").AddRow("test03"). + AddRow("test04").AddRow("test05")) + mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test01'"). + WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).AddRow("t1")) + fileName := "t2-invalid-schema.sql" + require.NoError(t, os.WriteFile(path.Join(tempDir, fileName), []byte("CREATE table t2 whatever;"), 0o644)) + dbMetas := []*MDDatabaseMeta{ + {Name: "test01", Tables: []*MDTableMeta{ + {DB: "test01", Name: "t1"}, + {DB: "test01", Name: "T2", charSet: "auto", + SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileName}}}, + }}, + } + require.ErrorContains(t, importer.Run(ctx, dbMetas), "line 1 column 24 near") + require.NoError(t, mock.ExpectationsWereMet()) + + // create table t2 downstream manually as workaround + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"}). + AddRow("test01").AddRow("test02").AddRow("test03"). + AddRow("test04").AddRow("test05")) + mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test01'"). + WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).AddRow("t1").AddRow("t2")) + require.NoError(t, importer.Run(ctx, dbMetas)) + require.NoError(t, mock.ExpectationsWereMet()) + require.NoError(t, os.Remove(path.Join(tempDir, fileName))) + }) + + t.Run("table: break on error", func(t *testing.T) { + importer2 := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 1) + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"}). + AddRow("test01").AddRow("test02").AddRow("test03"). + AddRow("test04").AddRow("test05")) + fileNameT1 := "test01.t1-schema.sql" + fileNameT2 := "test01.t2-schema.sql" + require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameT1), []byte("CREATE table t1(a int);"), 0o644)) + require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameT2), []byte("CREATE table t2(a int);"), 0o644)) + dbMetas := []*MDDatabaseMeta{ + {Name: "test01", Tables: []*MDTableMeta{ + {DB: "test01", Name: "t1", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameT1}}}, + {DB: "test01", Name: "t2", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameT2}}}, + }}, + } + mock.ExpectQuery("TABLES WHERE TABLE_SCHEMA = 'test01'"). + WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"})) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS `test01`.`t1`"). 
+ WillReturnError(errors.New("non retryable create table error")) + require.ErrorContains(t, importer2.Run(ctx, dbMetas), "non retryable create table error") + require.NoError(t, mock.ExpectationsWereMet()) + require.NoError(t, os.Remove(path.Join(tempDir, fileNameT1))) + require.NoError(t, os.Remove(path.Join(tempDir, fileNameT2))) + }) + + t.Run("view: get existing schema err", func(t *testing.T) { + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test01").AddRow("test02")) + mock.ExpectQuery("VIEWS WHERE TABLE_SCHEMA = 'test02'"). + WillReturnError(errors.New("non retryable error")) + dbMetas := []*MDDatabaseMeta{ + {Name: "test01"}, + {Name: "test02", Views: []*MDTableMeta{{DB: "test02", Name: "v"}}}, + } + require.ErrorContains(t, importer.Run(ctx, dbMetas), "non retryable error") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("view: fail on create", func(t *testing.T) { + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test01").AddRow("test02")) + mock.ExpectQuery("VIEWS WHERE TABLE_SCHEMA = 'test02'"). + WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"})) + fileNameV0 := "empty-file.sql" + fileNameV1 := "invalid-schema.sql" + fileNameV2 := "test02.v2-schema-view.sql" + require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameV0), []byte(""), 0o644)) + require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameV1), []byte("xxxx;"), 0o644)) + require.NoError(t, os.WriteFile(path.Join(tempDir, fileNameV2), []byte("create view v2 as select * from t;"), 0o644)) + dbMetas := []*MDDatabaseMeta{ + {Name: "test01"}, + {Name: "test02", Views: []*MDTableMeta{ + {DB: "test02", Name: "V0", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameV0}}}, + {DB: "test02", Name: "v1", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameV1}}}, + {DB: "test02", Name: "V2", charSet: "auto", SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileNameV2}}}}}, + } + require.ErrorContains(t, importer.Run(ctx, dbMetas), `line 1 column 4 near "xxxx;"`) + require.NoError(t, mock.ExpectationsWereMet()) + + // create view v2 downstream manually as workaround + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"}).AddRow("test01").AddRow("test02")) + mock.ExpectQuery("VIEWS WHERE TABLE_SCHEMA = 'test02'"). + WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}).AddRow("V1")) + mock.ExpectExec("VIEW `test02`.`V2` AS SELECT"). + WillReturnResult(sqlmock.NewResult(0, 0)) + require.NoError(t, importer.Run(ctx, dbMetas)) + require.NoError(t, mock.ExpectationsWereMet()) + require.NoError(t, os.Remove(path.Join(tempDir, fileNameV0))) + require.NoError(t, os.Remove(path.Join(tempDir, fileNameV1))) + require.NoError(t, os.Remove(path.Join(tempDir, fileNameV2))) + }) +} + +func TestSchemaImporterManyTables(t *testing.T) { + db, mock, err := sqlmock.New() + mock.MatchExpectationsInOrder(false) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, mock.ExpectationsWereMet()) + // have to ignore the error here, as sqlmock doesn't allow set number of + // expectations, and each opened connection requires a Close() call. 
+ _ = db.Close() + }) + ctx := context.Background() + tempDir := t.TempDir() + store, err := storage.NewLocalStorage(tempDir) + require.NoError(t, err) + logger := log.Logger{Logger: zap.NewExample()} + importer := NewSchemaImporter(logger, mysql.SQLMode(0), db, store, 8) + + mock.ExpectQuery(`information_schema.SCHEMATA`).WillReturnRows( + sqlmock.NewRows([]string{"SCHEMA_NAME"})) + dbMetas := make([]*MDDatabaseMeta, 0, 30) + for i := 0; i < 30; i++ { + dbName := fmt.Sprintf("test%02d", i) + dbMeta := &MDDatabaseMeta{Name: dbName, Tables: make([]*MDTableMeta, 0, 100)} + mock.ExpectExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", dbName)). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery(fmt.Sprintf("TABLES WHERE TABLE_SCHEMA = '%s'", dbName)). + WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"})) + for j := 0; j < 50; j++ { + tblName := fmt.Sprintf("t%03d", j) + fileName := fmt.Sprintf("%s.%s-schema.sql", dbName, tblName) + require.NoError(t, os.WriteFile(path.Join(tempDir, fileName), []byte(fmt.Sprintf("CREATE TABLE %s(a int);", tblName)), 0o644)) + mock.ExpectExec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s`.`%s`", dbName, tblName)). + WillReturnResult(sqlmock.NewResult(0, 0)) + dbMeta.Tables = append(dbMeta.Tables, &MDTableMeta{ + DB: dbName, Name: tblName, charSet: "auto", + SchemaFile: FileInfo{FileMeta: SourceFileMeta{Path: fileName}}, + }) + } + dbMetas = append(dbMetas, dbMeta) + } + require.NoError(t, importer.Run(ctx, dbMetas)) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestCreateTableIfNotExistsStmt(t *testing.T) { + dbName := "testdb" + p := parser.New() + createSQLIfNotExistsStmt := func(createTable, tableName string) []string { + res, err := createIfNotExistsStmt(p, createTable, dbName, tableName) + require.NoError(t, err) + return res + } + + require.Equal(t, []string{"CREATE DATABASE IF NOT EXISTS `testdb` CHARACTER SET = utf8 COLLATE = utf8_general_ci;"}, + createSQLIfNotExistsStmt("CREATE DATABASE `foo` CHARACTER SET = utf8 COLLATE = utf8_general_ci;", "")) + + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` TINYINT(1));", "foo")) + + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"}, + createSQLIfNotExistsStmt("CREATE TABLE IF NOT EXISTS `foo`(`bar` TINYINT(1));", "foo")) + + // case insensitive + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`fOo` (`bar` TINYINT(1));"}, + createSQLIfNotExistsStmt("/* cOmmEnt */ creAte tablE `fOo`(`bar` TinyinT(1));", "fOo")) + + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`FoO` (`bAR` TINYINT(1));"}, + createSQLIfNotExistsStmt("/* coMMenT */ crEatE tAble If not EXISts `FoO`(`bAR` tiNyInT(1));", "FoO")) + + // only one "CREATE TABLE" is replaced + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE');"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE');", "foo")) + + // test clustered index consistency + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] CLUSTERED */ COMMENT 'CREATE TABLE');"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY CLUSTERED COMMENT 'CREATE TABLE');", "foo")) + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] NONCLUSTERED 
*/);"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) NONCLUSTERED);", "foo")) + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');", "foo")) + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] CLUSTERED */);"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) /*T![clustered_index] CLUSTERED */);", "foo")) + + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![auto_rand] AUTO_RANDOM(2) */ COMMENT 'CREATE TABLE');"}, + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY AUTO_RANDOM(2) COMMENT 'CREATE TABLE');", "foo")) + + // upper case becomes shorter + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ſ` (`ı` TINYINT(1));"}, + createSQLIfNotExistsStmt("CREATE TABLE `ſ`(`ı` TINYINT(1));", "ſ")) + + // upper case becomes longer + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ɑ` (`ȿ` TINYINT(1));"}, + createSQLIfNotExistsStmt("CREATE TABLE `ɑ`(`ȿ` TINYINT(1));", "ɑ")) + + // non-utf-8 + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`\xcc\xcc\xcc` (`???` TINYINT(1));"}, + createSQLIfNotExistsStmt("CREATE TABLE `\xcc\xcc\xcc`(`\xdd\xdd\xdd` TINYINT(1));", "\xcc\xcc\xcc")) + + // renaming a table + require.Equal(t, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ba``r` (`x` INT);"}, + createSQLIfNotExistsStmt("create table foo(x int);", "ba`r")) + + // conditional comments + require.Equal(t, []string{ + "SET NAMES 'binary';", + "SET @@SESSION.`FOREIGN_KEY_CHECKS`=0;", + "CREATE TABLE IF NOT EXISTS `testdb`.`m` (`z` DOUBLE) ENGINE = InnoDB AUTO_INCREMENT = 8343230 DEFAULT CHARACTER SET = UTF8;", + }, + createSQLIfNotExistsStmt(` + /*!40101 SET NAMES binary*/; + /*!40014 SET FOREIGN_KEY_CHECKS=0*/; + CREATE TABLE x.y (z double) ENGINE=InnoDB AUTO_INCREMENT=8343230 DEFAULT CHARSET=utf8; + `, "m")) + + // create view + require.Equal(t, []string{ + "SET NAMES 'binary';", + "DROP TABLE IF EXISTS `testdb`.`m`;", + "DROP VIEW IF EXISTS `testdb`.`m`;", + "SET @`PREV_CHARACTER_SET_CLIENT`=@@`character_set_client`;", + "SET @`PREV_CHARACTER_SET_RESULTS`=@@`character_set_results`;", + "SET @`PREV_COLLATION_CONNECTION`=@@`collation_connection`;", + "SET @@SESSION.`character_set_client`=`utf8`;", + "SET @@SESSION.`character_set_results`=`utf8`;", + "SET @@SESSION.`collation_connection`=`utf8_general_ci`;", + "CREATE ALGORITHM = UNDEFINED DEFINER = `root`@`192.168.198.178` SQL SECURITY DEFINER VIEW `testdb`.`m` (`s`) AS SELECT `s` FROM `db1`.`v1` WHERE `i`<2;", + "SET @@SESSION.`character_set_client`=@`PREV_CHARACTER_SET_CLIENT`;", + "SET @@SESSION.`character_set_results`=@`PREV_CHARACTER_SET_RESULTS`;", + "SET @@SESSION.`collation_connection`=@`PREV_COLLATION_CONNECTION`;", + }, + createSQLIfNotExistsStmt(` + /*!40101 SET NAMES binary*/; + DROP TABLE IF EXISTS v2; + DROP VIEW IF EXISTS v2; + SET @PREV_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT; + SET @PREV_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS; + SET @PREV_COLLATION_CONNECTION=@@COLLATION_CONNECTION; + SET character_set_client = utf8; + SET character_set_results = utf8; + SET 
collation_connection = utf8_general_ci;
+	CREATE ALGORITHM=UNDEFINED DEFINER=root@192.168.198.178 SQL SECURITY DEFINER VIEW v2 (s) AS SELECT s FROM db1.v1 WHERE i<2;
+	SET character_set_client = @PREV_CHARACTER_SET_CLIENT;
+	SET character_set_results = @PREV_CHARACTER_SET_RESULTS;
+	SET collation_connection = @PREV_COLLATION_CONNECTION;
+	`, "m"))
+}
diff --git a/pkg/meta/meta.go b/pkg/meta/meta.go
index a231a0168206d..5b31b9a2f2392 100644
--- a/pkg/meta/meta.go
+++ b/pkg/meta/meta.go
@@ -63,7 +63,8 @@ var (
 // }
 var (
-	mMetaPrefix = []byte("m")
+	mMetaPrefix = []byte("m")
+	// Despite the name, the value stored under this key is the current max used ID, not the next ID.
 	mNextGlobalIDKey  = []byte("NextGlobalID")
 	mSchemaVersionKey = []byte("SchemaVersionKey")
 	mDBs              = []byte("DBs")
@@ -473,6 +474,8 @@ func (m *Meta) GetAutoIDAccessors(dbID, tableID int64) AutoIDAccessors {
 // To solve this problem, we always check the schema diff at first, if the diff is empty, we know at t2 moment we can only see the v9 schema,
 // so make neededSchemaVersion = neededSchemaVersion - 1.
 // For `Reload`, we can also do this: if the newest version's diff is not set yet, it is ok to load the previous version's infoSchema, and wait for the next reload.
+// If multiple consecutive jobs failed or were cancelled after the schema version
+// was increased, the returned 'version - 1' might still have no diff.
 func (m *Meta) GetSchemaVersionWithNonEmptyDiff() (int64, error) {
 	v, err := m.txn.GetInt64(mSchemaVersionKey)
 	if err != nil {
@@ -1267,6 +1270,8 @@ var (
 var (
 	// DefaultJobListKey keeps all actions of DDL jobs except "add index".
+	// This list and the one below are append-only, so the order of entries matches
+	// the jobs' creation order.
 	DefaultJobListKey JobListKeyType = mDDLJobListKey
 	// AddIndexJobListKey only keeps the action of adding index.
 	AddIndexJobListKey JobListKeyType = mDDLJobAddIdxList
diff --git a/pkg/parser/model/ddl.go b/pkg/parser/model/ddl.go
index df525f8dcbd47..5862949eae183 100644
--- a/pkg/parser/model/ddl.go
+++ b/pkg/parser/model/ddl.go
@@ -488,9 +488,13 @@ type JobMeta struct {
 // Job is for a DDL operation.
 type Job struct {
-	ID       int64      `json:"id"`
-	Type     ActionType `json:"type"`
-	SchemaID int64      `json:"schema_id"`
+	ID   int64      `json:"id"`
+	Type ActionType `json:"type"`
+	// SchemaID means different things for different job types:
+	// - ExchangeTablePartition: db id of non-partitioned table
+	SchemaID int64 `json:"schema_id"`
+	// TableID means different things for different job types:
+	// - ExchangeTablePartition: non-partitioned table id
 	TableID    int64  `json:"table_id"`
 	SchemaName string `json:"schema_name"`
 	TableName  string `json:"table_name"`
@@ -504,8 +508,15 @@ type Job struct {
 	Mu sync.Mutex `json:"-"`
 	// CtxVars are variables attached to the job. It is for internal usage.
 	// E.g. passing arguments between functions by one single *Job pointer.
+	// For ExchangeTablePartition, RenameTables and RenameTable, it's [slice-of-db-id, slice-of-table-id].
 	CtxVars []interface{} `json:"-"`
-	Args    []interface{} `json:"-"`
+	// Args holds the job arguments. Note: they might change as the job state changes,
+	// such as when AddColumn is rolled back.
+	// - CreateTable, it's [model.TableInfo, foreignKeyCheck]
+	// - AddIndex or AddPrimaryKey: [unique, ....
+	// - TruncateTable: [new-table-id, foreignKeyCheck, ...
+	// - RenameTable: [old-db-id, new-table-name, old-db-name]
+	// - ExchangeTablePartition: [partition-id, pt-db-id, pt-id, partition-name, with-validation]
+	Args []interface{} `json:"-"`
 	// RawArgs : We must use json raw message to delay parsing special args.
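	// [Editor's note] Illustrative sketch (not part of this diff) of how a DDL worker
	// decodes RawArgs once the job type is known, e.g. for TruncateTable per the list
	// above; the variable names are hypothetical:
	//
	//	var newTableID int64
	//	var fkCheck bool
	//	// DecodeArgs unmarshals RawArgs positionally into the given pointers.
	//	if err := job.DecodeArgs(&newTableID, &fkCheck); err != nil {
	//		return errors.Trace(err)
	//	}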
 	RawArgs json.RawMessage `json:"raw_args"`

 	SchemaState SchemaState `json:"schema_state"`
@@ -517,7 +528,7 @@ type Job struct {
 	// StartTS uses timestamp allocated by TSO.
 	// Now it's the TS when we put the job to TiKV queue.
 	StartTS uint64 `json:"start_ts"`
-	// DependencyID is the job's ID that the current job depends on.
+	// DependencyID is the largest job ID that precedes this job and that this job depends on.
 	DependencyID int64 `json:"dependency_id"`
 	// Query string of the ddl job.
 	Query string `json:"query"`
@@ -561,8 +572,9 @@ type Job struct {
 	// CDCWriteSource indicates the source of CDC write.
 	CDCWriteSource uint64 `json:"cdc_write_source"`

-	// LocalMode indicates whether the job is running in local TiDB.
-	// Only happens when tidb_enable_fast_ddl = on
+	// LocalMode = true means the job runs on the local TiDB instance the client
+	// connects to; otherwise it runs on the DDL owner.
+	// Only happens when tidb_enable_fast_create_table = on
 	LocalMode bool `json:"local_mode"`

 	// SQLMode for executing DDL query.
@@ -1016,9 +1028,9 @@ type JobState int32
 const (
 	JobStateNone    JobState = 0
 	JobStateRunning JobState = 1
+	// JobStateRollingback is the state for rolling back the job.
 	// When DDL encountered an unrecoverable error at reorganization state,
 	// some keys has been added already, we need to remove them.
-	// JobStateRollingback is the state to do the rolling back job.
 	JobStateRollingback  JobState = 2
 	JobStateRollbackDone JobState = 3
 	JobStateDone         JobState = 4
diff --git a/pkg/planner/context/context.go b/pkg/planner/context/context.go
index 0d78cc5edc54e..41b38ac2ead6e 100644
--- a/pkg/planner/context/context.go
+++ b/pkg/planner/context/context.go
@@ -37,6 +37,8 @@ type PlanContext interface {
 	GetRestrictedSQLExecutor() sqlexec.RestrictedSQLExecutor
 	// GetExprCtx gets the expression context.
 	GetExprCtx() exprctx.ExprContext
+	// GetNullRejectCheckExprCtx gets the expression context with the null-rejected check.
+	GetNullRejectCheckExprCtx() exprctx.ExprContext
 	// GetStore returns the store of session.
 	GetStore() kv.Storage
 	// GetSessionVars gets the session variables.
diff --git a/pkg/planner/contextimpl/BUILD.bazel b/pkg/planner/contextimpl/BUILD.bazel
index c1035df8e48f9..fa66dbfb08ea1 100644
--- a/pkg/planner/contextimpl/BUILD.bazel
+++ b/pkg/planner/contextimpl/BUILD.bazel
@@ -10,5 +10,6 @@ go_library(
         "//pkg/planner/context",
         "//pkg/sessionctx",
         "//pkg/sessiontxn",
+        "//pkg/util/intest",
     ],
 )
diff --git a/pkg/planner/contextimpl/impl.go b/pkg/planner/contextimpl/impl.go
index 04fc7015b5e7f..0d42754861e30 100644
--- a/pkg/planner/contextimpl/impl.go
+++ b/pkg/planner/contextimpl/impl.go
@@ -19,6 +19,7 @@ import (
 	"github.com/pingcap/tidb/pkg/planner/context"
 	"github.com/pingcap/tidb/pkg/sessionctx"
 	"github.com/pingcap/tidb/pkg/sessiontxn"
+	"github.com/pingcap/tidb/pkg/util/intest"
 )

 var _ context.PlanContext = struct {
@@ -28,13 +29,25 @@ var _ context.PlanContext = struct {

 // PlanCtxExtendedImpl provides extended method for session context to implement `PlanContext`
 type PlanCtxExtendedImpl struct {
-	sctx    sessionctx.Context
-	exprCtx exprctx.BuildContext
+	sctx                   sessionctx.Context
+	nullRejectCheckExprCtx *exprctx.NullRejectCheckExprContext
 }

 // NewPlanCtxExtendedImpl creates a new PlanCtxExtendedImpl.
func NewPlanCtxExtendedImpl(sctx sessionctx.Context) *PlanCtxExtendedImpl { - return &PlanCtxExtendedImpl{sctx: sctx} + return &PlanCtxExtendedImpl{ + sctx: sctx, + nullRejectCheckExprCtx: exprctx.WithNullRejectCheck(sctx.GetExprCtx()), + } +} + +// GetNullRejectCheckExprCtx returns a context with null rejected check +func (ctx *PlanCtxExtendedImpl) GetNullRejectCheckExprCtx() exprctx.ExprContext { + intest.AssertFunc(func() bool { + // assert `sctx.GetExprCtx()` should keep the same to avoid some unexpected behavior. + return ctx.nullRejectCheckExprCtx.ExprContext == ctx.sctx.GetExprCtx() + }) + return ctx.nullRejectCheckExprCtx } // AdviseTxnWarmup advises the txn to warm up. diff --git a/pkg/planner/core/casetest/planstats/main_test.go b/pkg/planner/core/casetest/planstats/main_test.go index f53fa7fc26c58..a1289ccab2a6c 100644 --- a/pkg/planner/core/casetest/planstats/main_test.go +++ b/pkg/planner/core/casetest/planstats/main_test.go @@ -40,6 +40,7 @@ func TestMain(m *testing.M) { goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"), goleak.IgnoreTopFunction("github.com/tikv/client-go/v2/txnkv/transaction.keepAlive"), goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), + goleak.IgnoreTopFunction("github.com/pingcap/tidb/pkg/statistics/handle/syncload.(*statsSyncLoad).SendLoadRequests.func1"), // For TestPlanStatsLoadTimeout } callback := func(i int) int { diff --git a/pkg/planner/core/casetest/rule/testdata/outer2inner_out.json b/pkg/planner/core/casetest/rule/testdata/outer2inner_out.json index 7cb1295f5d23f..2f99510904096 100644 --- a/pkg/planner/core/casetest/rule/testdata/outer2inner_out.json +++ b/pkg/planner/core/casetest/rule/testdata/outer2inner_out.json @@ -83,12 +83,12 @@ { "SQL": "select * from t2 left outer join t1 on a1=a2 where b2+b1 > 2; -- expression evaluates to UNKNOWN/FALSE even though we have fields from outer table", "Plan": [ - "HashJoin 12487.50 root inner join, equal:[eq(test.t2.a2, test.t1.a1)], other cond:gt(plus(test.t2.b2, test.t1.b1), 2)", - "├─TableReader(Build) 9990.00 root data:Selection", - "│ └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a1))", - "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", - "└─TableReader(Probe) 9990.00 root data:Selection", - " └─Selection 9990.00 cop[tikv] not(isnull(test.t2.a2))", + "Selection 9990.00 root gt(plus(test.t2.b2, test.t1.b1), 2)", + "└─HashJoin 12487.50 root left outer join, equal:[eq(test.t2.a2, test.t1.a1)]", + " ├─TableReader(Build) 9990.00 root data:Selection", + " │ └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a1))", + " │ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + " └─TableReader(Probe) 10000.00 root data:TableFullScan", " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo" ] }, @@ -147,51 +147,50 @@ { "SQL": "select * from t1 ta left outer join (t1 tb left outer join t1 tc on tb.b1 = tc.b1) on ta.a1=tc.a1; -- nested join. 
On clause is null filtering on tc.", "Plan": [ - "HashJoin 15593.77 root left outer join, equal:[eq(test.t1.a1, test.t1.a1)]", - "├─TableReader(Build) 10000.00 root data:TableFullScan", - "│ └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo", - "└─Projection(Probe) 12475.01 root test.t1.a1, test.t1.b1, test.t1.c1, test.t1.a1, test.t1.b1, test.t1.c1", - " └─HashJoin 12475.01 root inner join, equal:[eq(test.t1.b1, test.t1.b1)]", - " ├─TableReader(Build) 9980.01 root data:Selection", - " │ └─Selection 9980.01 cop[tikv] not(isnull(test.t1.a1)), not(isnull(test.t1.b1))", - " │ └─TableFullScan 10000.00 cop[tikv] table:tc keep order:false, stats:pseudo", - " └─TableReader(Probe) 9990.00 root data:Selection", - " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.b1))", - " └─TableFullScan 10000.00 cop[tikv] table:tb keep order:false, stats:pseudo" + "HashJoin 12487.50 root left outer join, equal:[eq(test.t1.a1, test.t1.a1)]", + "├─Selection(Build) 9990.00 root not(isnull(test.t1.a1))", + "│ └─HashJoin 12487.50 root left outer join, equal:[eq(test.t1.b1, test.t1.b1)]", + "│ ├─TableReader(Build) 9990.00 root data:Selection", + "│ │ └─Selection 9990.00 cop[tikv] not(isnull(test.t1.b1))", + "│ │ └─TableFullScan 10000.00 cop[tikv] table:tc keep order:false, stats:pseudo", + "│ └─TableReader(Probe) 10000.00 root data:TableFullScan", + "│ └─TableFullScan 10000.00 cop[tikv] table:tb keep order:false, stats:pseudo", + "└─TableReader(Probe) 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" ] }, { "SQL": "select * from t1 ta left outer join (t1 tb left outer join t1 tc on tb.b1 = tc.b1) on ta.a1=tc.a1 where tb.a1 > 5; -- nested join. On clause and WHERE clause are filters", "Plan": [ - "Projection 5203.12 root test.t1.a1, test.t1.b1, test.t1.c1, test.t1.a1, test.t1.b1, test.t1.c1, test.t1.a1, test.t1.b1, test.t1.c1", - "└─HashJoin 5203.12 root inner join, equal:[eq(test.t1.a1, test.t1.a1)]", - " ├─HashJoin(Build) 4162.50 root inner join, equal:[eq(test.t1.b1, test.t1.b1)]", - " │ ├─TableReader(Build) 3330.00 root data:Selection", - " │ │ └─Selection 3330.00 cop[tikv] gt(test.t1.a1, 5), not(isnull(test.t1.b1))", - " │ │ └─TableFullScan 10000.00 cop[tikv] table:tb keep order:false, stats:pseudo", - " │ └─TableReader(Probe) 9980.01 root data:Selection", - " │ └─Selection 9980.01 cop[tikv] not(isnull(test.t1.a1)), not(isnull(test.t1.b1))", - " │ └─TableFullScan 10000.00 cop[tikv] table:tc keep order:false, stats:pseudo", - " └─TableReader(Probe) 9990.00 root data:Selection", - " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a1))", - " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + "HashJoin 4166.67 root inner join, equal:[eq(test.t1.a1, test.t1.a1)]", + "├─Selection(Build) 3333.33 root not(isnull(test.t1.a1))", + "│ └─HashJoin 4166.67 root left outer join, equal:[eq(test.t1.b1, test.t1.b1)]", + "│ ├─TableReader(Build) 3333.33 root data:Selection", + "│ │ └─Selection 3333.33 cop[tikv] gt(test.t1.a1, 5)", + "│ │ └─TableFullScan 10000.00 cop[tikv] table:tb keep order:false, stats:pseudo", + "│ └─TableReader(Probe) 9990.00 root data:Selection", + "│ └─Selection 9990.00 cop[tikv] not(isnull(test.t1.b1))", + "│ └─TableFullScan 10000.00 cop[tikv] table:tc keep order:false, stats:pseudo", + "└─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a1))", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" ] }, { "SQL": 
"select * from (t2 left join t1 on a1=a2) join t3 on b1=b3 -- on clause applied nested join", "Plan": [ - "Projection 15593.77 root test.t2.a2, test.t2.b2, test.t2.c2, test.t1.a1, test.t1.b1, test.t1.c1, test.t3.a3, test.t3.b3, test.t3.c3", - "└─HashJoin 15593.77 root inner join, equal:[eq(test.t1.b1, test.t3.b3)]", - " ├─TableReader(Build) 9990.00 root data:Selection", - " │ └─Selection 9990.00 cop[tikv] not(isnull(test.t3.b3))", - " │ └─TableFullScan 10000.00 cop[tikv] table:t3 keep order:false, stats:pseudo", - " └─HashJoin(Probe) 12475.01 root inner join, equal:[eq(test.t1.a1, test.t2.a2)]", - " ├─TableReader(Build) 9980.01 root data:Selection", - " │ └─Selection 9980.01 cop[tikv] not(isnull(test.t1.a1)), not(isnull(test.t1.b1))", - " │ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", - " └─TableReader(Probe) 9990.00 root data:Selection", - " └─Selection 9990.00 cop[tikv] not(isnull(test.t2.a2))", - " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo" + "Projection 12487.50 root test.t2.a2, test.t2.b2, test.t2.c2, test.t1.a1, test.t1.b1, test.t1.c1, test.t3.a3, test.t3.b3, test.t3.c3", + "└─HashJoin 12487.50 root inner join, equal:[eq(test.t3.b3, test.t1.b1)]", + " ├─Selection(Build) 9990.00 root not(isnull(test.t1.b1))", + " │ └─HashJoin 12487.50 root left outer join, equal:[eq(test.t2.a2, test.t1.a1)]", + " │ ├─TableReader(Build) 9990.00 root data:Selection", + " │ │ └─Selection 9990.00 cop[tikv] not(isnull(test.t1.a1))", + " │ │ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo", + " │ └─TableReader(Probe) 10000.00 root data:TableFullScan", + " │ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo", + " └─TableReader(Probe) 9990.00 root data:Selection", + " └─Selection 9990.00 cop[tikv] not(isnull(test.t3.b3))", + " └─TableFullScan 10000.00 cop[tikv] table:t3 keep order:false, stats:pseudo" ] }, { diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index bde3bc3a45430..4d588efd1b7f7 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -1416,102 +1416,103 @@ func (p *LogicalJoin) constructInnerIndexScanTask( // // Step2: build other inner plan node to task func (p *LogicalJoin) constructIndexJoinInnerSideTask(dsCopTask *CopTask, ds *DataSource, path *util.AccessPath, wrapper *indexJoinInnerChildWrapper) base.Task { - if len(wrapper.zippedChildren) == 0 { - t := dsCopTask.ConvertToRootTask(ds.SCtx()) - return t - } - la, canPushAggToCop := wrapper.zippedChildren[len(wrapper.zippedChildren)-1].(*LogicalAggregation) - if la != nil && la.HasDistinct() { - // TODO: remove AllowDistinctAggPushDown after the cost estimation of distinct pushdown is implemented. - // If AllowDistinctAggPushDown is set to true, we should not consider RootTask. - if !la.SCtx().GetSessionVars().AllowDistinctAggPushDown { - canPushAggToCop = false + var la *LogicalAggregation + var canPushAggToCop bool + if len(wrapper.zippedChildren) > 0 { + la, canPushAggToCop = wrapper.zippedChildren[len(wrapper.zippedChildren)-1].(*LogicalAggregation) + if la != nil && la.HasDistinct() { + // TODO: remove AllowDistinctAggPushDown after the cost estimation of distinct pushdown is implemented. + // If AllowDistinctAggPushDown is set to true, we should not consider RootTask. + if !la.SCtx().GetSessionVars().AllowDistinctAggPushDown { + canPushAggToCop = false + } } } - if canPushAggToCop { - // Try stream aggregation first. 
- // We will choose the stream aggregation if the following conditions are met: - // 1. Force hint stream agg by /*+ stream_agg() */ - // 2. Other conditions copy from getStreamAggs() in exhaust_physical_plans.go - _, preferStream := la.ResetHintIfConflicted() - for _, aggFunc := range la.AggFuncs { - if aggFunc.Mode == aggregation.FinalMode { - preferStream = false - break - } - } - // group by a + b is not interested in any order. - groupByCols := la.GetGroupByCols() - if len(groupByCols) != len(la.GroupByItems) { + // If the bottom plan is not aggregation or the aggregation can't be pushed to coprocessor, we will construct a root task directly. + if !canPushAggToCop { + result := dsCopTask.ConvertToRootTask(ds.SCtx()).(*RootTask) + result.SetPlan(p.constructInnerByZippedChildren(wrapper.zippedChildren, result.GetPlan())) + return result + } + + // Try stream aggregation first. + // We will choose the stream aggregation if the following conditions are met: + // 1. Force hint stream agg by /*+ stream_agg() */ + // 2. Other conditions copy from getStreamAggs() in exhaust_physical_plans.go + _, preferStream := la.ResetHintIfConflicted() + for _, aggFunc := range la.AggFuncs { + if aggFunc.Mode == aggregation.FinalMode { preferStream = false + break } - if la.HasDistinct() && !la.distinctArgsMeetsProperty() { + } + // group by a + b is not interested in any order. + groupByCols := la.GetGroupByCols() + if len(groupByCols) != len(la.GroupByItems) { + preferStream = false + } + if la.HasDistinct() && !la.distinctArgsMeetsProperty() { + preferStream = false + } + // sort items must be the super set of group by items + if path != nil && path.Index != nil && !path.Index.MVIndex && + ds.tableInfo.GetPartitionInfo() == nil { + if len(path.IdxCols) < len(groupByCols) { preferStream = false } - // sort items must be the super set of group by items - if path != nil && path.Index != nil && !path.Index.MVIndex && - ds.tableInfo.GetPartitionInfo() == nil { - if len(path.IdxCols) < len(groupByCols) { + sctx := p.SCtx() + for i, groupbyCol := range groupByCols { + if path.IdxColLens[i] != types.UnspecifiedLength || + !groupbyCol.EqualByExprAndID(sctx.GetExprCtx().GetEvalCtx(), path.IdxCols[i]) { preferStream = false } - sctx := p.SCtx() - for i, groupbyCol := range groupByCols { - if path.IdxColLens[i] != types.UnspecifiedLength || - !groupbyCol.EqualByExprAndID(sctx.GetExprCtx().GetEvalCtx(), path.IdxCols[i]) { - preferStream = false - } - } - } else { - preferStream = false } + } else { + preferStream = false + } - // build physical agg and attach to task - var aggTask base.Task - // build stream agg and change ds keep order to true - if preferStream { - newGbyItems := make([]expression.Expression, len(la.GroupByItems)) - copy(newGbyItems, la.GroupByItems) - newAggFuncs := make([]*aggregation.AggFuncDesc, len(la.AggFuncs)) - copy(newAggFuncs, la.AggFuncs) - streamAgg := basePhysicalAgg{ - GroupByItems: newGbyItems, - AggFuncs: newAggFuncs, - }.initForStream(la.SCtx(), la.StatsInfo(), la.QueryBlockOffset(), nil) - streamAgg.SetSchema(la.schema.Clone()) - // change to keep order for index scan and dsCopTask - if dsCopTask.indexPlan != nil { - // get the index scan from dsCopTask.indexPlan - physicalIndexScan, _ := dsCopTask.indexPlan.(*PhysicalIndexScan) - if physicalIndexScan == nil && len(dsCopTask.indexPlan.Children()) == 1 { - physicalIndexScan, _ = dsCopTask.indexPlan.Children()[0].(*PhysicalIndexScan) - } - if physicalIndexScan != nil { - physicalIndexScan.KeepOrder = true - dsCopTask.keepOrder = 
true - aggTask = streamAgg.Attach2Task(dsCopTask) - } + // build physical agg and attach to task + var aggTask base.Task + // build stream agg and change ds keep order to true + if preferStream { + newGbyItems := make([]expression.Expression, len(la.GroupByItems)) + copy(newGbyItems, la.GroupByItems) + newAggFuncs := make([]*aggregation.AggFuncDesc, len(la.AggFuncs)) + copy(newAggFuncs, la.AggFuncs) + streamAgg := basePhysicalAgg{ + GroupByItems: newGbyItems, + AggFuncs: newAggFuncs, + }.initForStream(la.SCtx(), la.StatsInfo(), la.QueryBlockOffset(), nil) + streamAgg.SetSchema(la.schema.Clone()) + // change to keep order for index scan and dsCopTask + if dsCopTask.indexPlan != nil { + // get the index scan from dsCopTask.indexPlan + physicalIndexScan, _ := dsCopTask.indexPlan.(*PhysicalIndexScan) + if physicalIndexScan == nil && len(dsCopTask.indexPlan.Children()) == 1 { + physicalIndexScan, _ = dsCopTask.indexPlan.Children()[0].(*PhysicalIndexScan) + } + if physicalIndexScan != nil { + physicalIndexScan.KeepOrder = true + dsCopTask.keepOrder = true + aggTask = streamAgg.Attach2Task(dsCopTask) } } + } - // build hash agg, when the stream agg is illegal such as the order by prop is not matched - if aggTask == nil { - physicalHashAgg := NewPhysicalHashAgg(la, la.StatsInfo(), nil) - physicalHashAgg.SetSchema(la.schema.Clone()) - aggTask = physicalHashAgg.Attach2Task(dsCopTask) - } - - // build other inner plan node to task - result, ok := aggTask.(*RootTask) - if !ok { - return nil - } - result.p = p.constructInnerByZippedChildren(wrapper.zippedChildren[0:len(wrapper.zippedChildren)-1], result.p) - return result + // build hash agg, when the stream agg is illegal such as the order by prop is not matched + if aggTask == nil { + physicalHashAgg := NewPhysicalHashAgg(la, la.StatsInfo(), nil) + physicalHashAgg.SetSchema(la.schema.Clone()) + aggTask = physicalHashAgg.Attach2Task(dsCopTask) } - result := dsCopTask.ConvertToRootTask(ds.SCtx()).(*RootTask) - result.SetPlan(p.constructInnerByZippedChildren(wrapper.zippedChildren, result.GetPlan())) + // build other inner plan node to task + result, ok := aggTask.(*RootTask) + if !ok { + return nil + } + result.SetPlan(p.constructInnerByZippedChildren(wrapper.zippedChildren[0:len(wrapper.zippedChildren)-1], result.p)) return result } diff --git a/pkg/planner/core/expression_rewriter.go b/pkg/planner/core/expression_rewriter.go index 147bff418e360..873295ca99ea7 100644 --- a/pkg/planner/core/expression_rewriter.go +++ b/pkg/planner/core/expression_rewriter.go @@ -1493,7 +1493,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok return retNode, false } - castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp) + castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp, false) if err != nil { er.err = err return retNode, false diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index 4e83b11413e35..1341222b519b6 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -249,6 +249,7 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p base.LogicalPlan, b.optFlag |= flagPredicatePushDown b.optFlag |= flagEliminateAgg b.optFlag |= flagEliminateProjection + b.optFlag |= flagConvertOuterToInnerJoin if b.ctx.GetSessionVars().EnableSkewDistinctAgg { b.optFlag |= flagSkewDistinctAgg @@ -562,13 +563,13 @@ func (p *LogicalJoin) ExtractOnCondition( } if leftCol != nil && rightCol != nil { 
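	// [Editor's note] Sketch of what the deriveLeft/deriveRight branches below do: for an
	// ON condition such as `t1.a1 = t2.a2`, the equality is null-rejected on each side, so
	// a derived `not(isnull(t1.a1))` (and likewise `not(isnull(t2.a2))`) filter is built
	// with expression.BuildNotNullExpr and pushed to the corresponding child, which is
	// what produces the `Selection ... not(isnull(...))` operators in the plans above.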
if deriveLeft { - if isNullFilteredOneExpr(ctx, leftSchema, expr) && !mysql.HasNotNullFlag(leftCol.RetType.GetFlag()) { + if util.IsNullRejected(ctx, leftSchema, expr) && !mysql.HasNotNullFlag(leftCol.RetType.GetFlag()) { notNullExpr := expression.BuildNotNullExpr(ctx.GetExprCtx(), leftCol) leftCond = append(leftCond, notNullExpr) } } if deriveRight { - if isNullFilteredOneExpr(ctx, rightSchema, expr) && !mysql.HasNotNullFlag(rightCol.RetType.GetFlag()) { + if util.IsNullRejected(ctx, rightSchema, expr) && !mysql.HasNotNullFlag(rightCol.RetType.GetFlag()) { notNullExpr := expression.BuildNotNullExpr(ctx.GetExprCtx(), rightCol) rightCond = append(rightCond, notNullExpr) } @@ -1290,6 +1291,7 @@ func (b *PlanBuilder) buildSelection(ctx context.Context, p base.LogicalPlan, wh if b.curClause != havingClause { b.curClause = whereClause } + b.optFlag |= flagConvertOuterToInnerJoin conditions := splitWhere(where) expressions := make([]expression.Expression, 0, len(conditions)) @@ -5533,7 +5535,7 @@ func (b *PlanBuilder) buildProjUponView(_ context.Context, dbName model.CIStr, t // buildApplyWithJoinType builds apply plan with outerPlan and innerPlan, which apply join with particular join type for // every row from outerPlan and the whole innerPlan. func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan base.LogicalPlan, tp JoinType, markNoDecorrelate bool) base.LogicalPlan { - b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate + b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate | flagConvertOuterToInnerJoin ap := LogicalApply{LogicalJoin: LogicalJoin{JoinType: tp}, NoDecorrelate: markNoDecorrelate}.Init(b.ctx, b.getSelectOffset()) ap.SetChildren(outerPlan, innerPlan) ap.names = make([]*types.FieldName, outerPlan.Schema().Len()+innerPlan.Schema().Len()) @@ -5555,7 +5557,7 @@ func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan base.LogicalPl // buildSemiApply builds apply plan with outerPlan and innerPlan, which apply semi-join for every row from outerPlan and the whole innerPlan. func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan base.LogicalPlan, condition []expression.Expression, asScalar, not, considerRewrite, markNoDecorrelate bool) (base.LogicalPlan, error) { - b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate + b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate | flagConvertOuterToInnerJoin join, err := b.buildSemiJoin(outerPlan, innerPlan, condition, asScalar, not, considerRewrite) if err != nil { diff --git a/pkg/planner/core/logical_plans.go b/pkg/planner/core/logical_plans.go index bb123a0e587cf..828ad7809dd39 100644 --- a/pkg/planner/core/logical_plans.go +++ b/pkg/planner/core/logical_plans.go @@ -875,7 +875,7 @@ func (p *LogicalProjection) ExtractFD() *fd.FDSet { // the dependent columns in scalar function should be also considered as output columns as well. 
outputColsUniqueIDs.Insert(int(one.UniqueID)) } - notnull := isNullFilteredOneExpr(p.SCtx(), p.schema, x) + notnull := util.IsNullRejected(p.SCtx(), p.schema, x) if notnull || determinants.SubsetOf(fds.NotNullCols) { notnullColsUniqueIDs.Insert(scalarUniqueID) } @@ -1015,7 +1015,7 @@ func (la *LogicalAggregation) ExtractFD() *fd.FDSet { determinants.Insert(int(one.UniqueID)) groupByColsOutputCols.Insert(int(one.UniqueID)) } - notnull := isNullFilteredOneExpr(la.SCtx(), la.schema, x) + notnull := util.IsNullRejected(la.SCtx(), la.schema, x) if notnull || determinants.SubsetOf(fds.NotNullCols) { notnullColsUniqueIDs.Insert(scalarUniqueID) } @@ -1182,7 +1182,7 @@ func extractNotNullFromConds(conditions []expression.Expression, p base.LogicalP for _, condition := range conditions { var cols []*expression.Column cols = expression.ExtractColumnsFromExpressions(cols, []expression.Expression{condition}, nil) - if isNullFilteredOneExpr(p.SCtx(), p.Schema(), condition) { + if util.IsNullRejected(p.SCtx(), p.Schema(), condition) { for _, col := range cols { notnullColsUniqueIDs.Insert(int(col.UniqueID)) } diff --git a/pkg/planner/core/optimizer.go b/pkg/planner/core/optimizer.go index fcf7383a8b1d7..473e62402d6e7 100644 --- a/pkg/planner/core/optimizer.go +++ b/pkg/planner/core/optimizer.go @@ -998,7 +998,7 @@ func logicalOptimize(ctx context.Context, flag uint64, logic base.LogicalPlan) ( // The order of flags is same as the order of optRule in the list. // We use a bitmask to record which opt rules should be used. If the i-th bit is 1, it means we should // apply i-th optimizing rule. - if (flag&(1<DataScan(s)}(test.t.a,test.t.a)->Projection", "Join{DataScan(t)->Aggr(count(test.t.c),firstrow(test.t.a))->DataScan(s)}(test.t.a,test.t.a)->Projection->Projection", "Join{DataScan(t)->Aggr(count(test.t.c),firstrow(test.t.a))->DataScan(s)}(test.t.a,test.t.a)->Aggr(firstrow(Column#25),count(test.t.b))->Projection->Projection", - "Join{DataScan(t)->DataScan(s)->Sel([eq(test.t.a, test.t.a)])->Aggr(count(test.t.b))}(test.t.c,Column#37)->Projection", + "Apply{DataScan(t)->DataScan(s)->Sel([eq(test.t.a, test.t.a)])->Aggr(count(test.t.b))}->Projection", "Join{DataScan(t)->DataScan(s)->Aggr(count(test.t.b),firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->Projection->Projection", "Join{Join{DataScan(t1)->DataScan(t2)}->DataScan(s)->Aggr(count(test.t.b),firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->Projection->Projection", "Join{DataScan(t)->DataScan(s)->Aggr(count(1),firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->Projection->Projection", @@ -120,7 +120,7 @@ "Join{DataScan(t1)->DataScan(t)->Projection->Limit}(test.t.b,test.t.b)->Projection->Projection", "Join{DataScan(t)->Join{DataScan(s)->DataScan(k)}(test.t.d,test.t.d)(test.t.c,test.t.c)->Aggr(sum(test.t.a))->Projection}->Projection", "Join{DataScan(t1)->DataScan(t2)->Aggr(max(test.t.a),firstrow(test.t.b))}(test.t.b,test.t.b)->Projection->Sel([eq(test.t.b, Column#25)])->Projection", - "Join{DataScan(t1)->DataScan(t2)->Sel([eq(test.t.g, test.t.g) or(eq(test.t.b, 4), eq(test.t.b, 2))])->Aggr(avg(test.t.a))}->Projection->Sel([eq(cast(test.t.b, decimal(10,0) BINARY), Column#25)])->Projection", + "Apply{DataScan(t1)->DataScan(t2)->Sel([eq(test.t.g, test.t.g) or(eq(test.t.b, 4), eq(test.t.b, 2))])->Aggr(avg(test.t.a))}->Projection->Sel([eq(cast(test.t.b, decimal(10,0) BINARY), Column#25)])->Projection", "Join{DataScan(t1)->DataScan(t2)->Aggr(max(test.t.a),firstrow(test.t.b))}(test.t.b,test.t.b)->Projection->Sel([eq(test.t.b, Column#25)])->Projection", 
"Join{DataScan(t1)->DataScan(t2)}(test.t.a,test.t.a)(test.t.b,test.t.b)->Projection", "Join{DataScan(t1)->DataScan(t2)}(test.t.a,test.t.a)->Projection", @@ -135,8 +135,8 @@ "Join{Join{DataScan(t)->DataScan(x)->Aggr(firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->DataScan(x)->Aggr(firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->Projection", "Join{Join{DataScan(t)->DataScan(x)}(test.t.a,test.t.a)->DataScan(x)->Aggr(firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->Projection", "Join{Join{DataScan(t)->DataScan(x)->Aggr(firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->DataScan(x)->Aggr(firstrow(test.t.a))}(test.t.a,test.t.a)->Projection->Projection", - "Join{DataScan(t1)->DataScan(t2)->Sel([eq(test.t.a, test.t.a)])->Projection->Sort->Limit}->Projection->Sel([eq(test.t.b, test.t.b)])->Projection", - "Join{DataScan(t2)->DataScan(t1)->Sel([eq(test.t.a, test.t.a)])->Projection}->Projection", + "Apply{DataScan(t1)->DataScan(t2)->Sel([eq(test.t.a, test.t.a)])->Projection->Sort->Limit}->Projection->Sel([eq(test.t.b, test.t.b)])->Projection", + "Apply{DataScan(t2)->DataScan(t1)->Sel([eq(test.t.a, test.t.a)])->Projection}->Projection", "DataScan(t2)->Aggr(count(1))->Projection" ] }, diff --git a/pkg/planner/util/BUILD.bazel b/pkg/planner/util/BUILD.bazel index af04ffaa5c027..7fdf3291fb3c1 100644 --- a/pkg/planner/util/BUILD.bazel +++ b/pkg/planner/util/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "byitem.go", "expression.go", "misc.go", + "null_misc.go", "path.go", ], importpath = "github.com/pingcap/tidb/pkg/planner/util", diff --git a/pkg/planner/util/null_misc.go b/pkg/planner/util/null_misc.go new file mode 100644 index 0000000000000..5549eb1697c9a --- /dev/null +++ b/pkg/planner/util/null_misc.go @@ -0,0 +1,94 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/planner/context" +) + +// IsNullRejected check whether a condition is null-rejected +// A condition would be null-rejected in one of following cases: +// If it is a predicate containing a reference to an inner table that evaluates to UNKNOWN or FALSE +// when one of its arguments is NULL. +// If it is a conjunction containing a null-rejected condition as a conjunct. +// If it is a disjunction of null-rejected conditions. 
+func IsNullRejected(ctx context.PlanContext, schema *expression.Schema, expr expression.Expression) bool {
+	exprCtx := ctx.GetNullRejectCheckExprCtx()
+	expr = expression.PushDownNot(exprCtx, expr)
+	if expression.ContainOuterNot(expr) {
+		return false
+	}
+	sc := ctx.GetSessionVars().StmtCtx
+	for _, cond := range expression.SplitCNFItems(expr) {
+		if isNullRejectedSpecially(ctx, schema, cond) {
+			return true
+		}
+
+		result := expression.EvaluateExprWithNull(exprCtx, schema, cond)
+		x, ok := result.(*expression.Constant)
+		if !ok {
+			continue
+		}
+		if x.Value.IsNull() {
+			return true
+		} else if isTrue, err := x.Value.ToBool(sc.TypeCtxOrDefault()); err == nil && isTrue == 0 {
+			return true
+		}
+	}
+	return false
+}
+
+// isNullRejectedSpecially handles some null-rejected cases specially, since the current check in
+// EvaluateExprWithNull is too strict for some cases, e.g. #49616.
+func isNullRejectedSpecially(ctx context.PlanContext, schema *expression.Schema, expr expression.Expression) bool {
+	return specialNullRejectedCase1(ctx, schema, expr) // only 1 case now
+}
+
+// specialNullRejectedCase1 is mainly for #49616.
+// Case1 specially handles `null-rejected OR (null-rejected AND {others})`: no matter what the result
+// of `{others}` is (True, False or Null), the whole predicate evaluates to UNKNOWN or FALSE on NULL
+// input, so it is null-rejected.
+func specialNullRejectedCase1(ctx context.PlanContext, schema *expression.Schema, expr expression.Expression) bool {
+	isFunc := func(e expression.Expression, lowerFuncName string) *expression.ScalarFunction {
+		f, ok := e.(*expression.ScalarFunction)
+		if !ok {
+			return nil
+		}
+		if f.FuncName.L == lowerFuncName {
+			return f
+		}
+		return nil
+	}
+	orFunc := isFunc(expr, ast.LogicOr)
+	if orFunc == nil {
+		return false
+	}
+	for i := 0; i < 2; i++ {
+		andFunc := isFunc(orFunc.GetArgs()[i], ast.LogicAnd)
+		if andFunc == nil {
+			continue
+		}
+		if !IsNullRejected(ctx, schema, orFunc.GetArgs()[1-i]) {
+			continue // the other side should be null-rejected: null-rejected OR (... AND ...)
+		}
+		for _, andItem := range expression.SplitCNFItems(andFunc) {
+			if IsNullRejected(ctx, schema, andItem) {
+				return true // hit the case in the comment: null-rejected OR (null-rejected AND ...)
+			}
+		}
+	}
+	return false
+}
diff --git a/pkg/server/conn_test.go b/pkg/server/conn_test.go
index 20c501294bd04..e0f7b19d0d62d 100644
--- a/pkg/server/conn_test.go
+++ b/pkg/server/conn_test.go
@@ -838,7 +838,7 @@ func TestPrefetchPointKeys4Update(t *testing.T) {
 	require.True(t, txn.Valid())
 	snap := txn.GetSnapshot()
 	//nolint:forcetypeassert
-	require.Equal(t, 4, snap.(snapshotCache).SnapCacheHitCount())
+	require.Equal(t, 6, snap.(snapshotCache).SnapCacheHitCount())
 	tk.MustExec("commit")
 	tk.MustQuery("select * from prefetch").Check(testkit.Rows("1 1 2", "2 2 4", "3 3 4"))
@@ -888,7 +888,7 @@ func TestPrefetchPointKeys4Delete(t *testing.T) {
 	require.True(t, txn.Valid())
 	snap := txn.GetSnapshot()
 	//nolint:forcetypeassert
-	require.Equal(t, 4, snap.(snapshotCache).SnapCacheHitCount())
+	require.Equal(t, 6, snap.(snapshotCache).SnapCacheHitCount())
 	tk.MustExec("commit")
 	tk.MustQuery("select * from prefetch").Check(testkit.Rows("4 4 4", "5 5 5", "6 6 6"))
diff --git a/pkg/session/session.go b/pkg/session/session.go
index b5285509a6904..3a4469f02caf2 100644
--- a/pkg/session/session.go
+++ b/pkg/session/session.go
@@ -3125,7 +3125,11 @@ var (
 		{ddl.BackgroundSubtaskTableSQL, ddl.BackgroundSubtaskTableID},
 		{ddl.BackgroundSubtaskHistoryTableSQL, ddl.BackgroundSubtaskHistoryTableID},
 	}
-	mdlTable = "create table mysql.tidb_mdl_info(job_id BIGINT NOT NULL PRIMARY KEY, version BIGINT NOT NULL, table_ids text(65535));"
+	mdlTable = `create table mysql.tidb_mdl_info(
+		job_id BIGINT NOT NULL PRIMARY KEY,
+		version BIGINT NOT NULL,
+		table_ids text(65535)
+	);`
 )

 func splitAndScatterTable(store kv.Storage, tableIDs []int64) {
diff --git a/pkg/sessionctx/context.go b/pkg/sessionctx/context.go
index b2c3efbad6f71..ad835a4eebff5 100644
--- a/pkg/sessionctx/context.go
+++ b/pkg/sessionctx/context.go
@@ -72,6 +72,7 @@ type Context interface {
 	// RollbackTxn rolls back the current transaction.
 	RollbackTxn(ctx context.Context)
 	// CommitTxn commits the current transaction.
+	// Buffered KV changes will be discarded; call StmtCommit first if you want to commit them.
 	CommitTxn(ctx context.Context) error
 	// Txn returns the current transaction which is created before executing a statement.
 	// The returned kv.Transaction is not nil, but it maybe pending or invalid.
@@ -141,9 +142,17 @@ type Context interface {
 	HasDirtyContent(tid int64) bool
 	// StmtCommit flush all changes by the statement to the underlying transaction.
+	// It must be called before CommitTxn, else all changes since the last StmtCommit
+	// will be lost. For SQL statements, StmtCommit or StmtRollback is called automatically.
+	// Here "Stmt" means not only a SQL statement, but any batch of KV changes, such as
+	// meta KV changes.
 	StmtCommit(ctx context.Context)
 	// StmtRollback provides statement level rollback. The parameter `forPessimisticRetry` should be true iff it's used
 	// for auto-retrying execution of DMLs in pessimistic transactions.
+	// If an error happens while you are handling a batch of KV changes after the last
+	// StmtCommit or StmtRollback, and you don't want that batch committed, you must call
+	// StmtRollback before starting another batch; otherwise the previous changes might be
+	// committed unexpectedly.
 	StmtRollback(ctx context.Context, isForPessimisticRetry bool)
 	// StmtGetMutation gets the binlog mutation for current statement.
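	// [Editor's note] A minimal sketch of the StmtCommit/CommitTxn contract documented
	// above, assuming a Context value `sctx` and a context.Context `ctx` (illustrative):
	//
	//	// ... write KV changes through the session's statement buffer ...
	//	sctx.StmtCommit(ctx)                        // flush the statement buffer into the txn
	//	if err := sctx.CommitTxn(ctx); err != nil {
	//		// handle the commit error
	//	}
	//	// Without the StmtCommit call, CommitTxn would discard the buffered changes.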
StmtGetMutation(int64) *binlog.TableMutation diff --git a/pkg/sessionctx/stmtctx/BUILD.bazel b/pkg/sessionctx/stmtctx/BUILD.bazel index f9d7583b03130..d766981321934 100644 --- a/pkg/sessionctx/stmtctx/BUILD.bazel +++ b/pkg/sessionctx/stmtctx/BUILD.bazel @@ -29,6 +29,7 @@ go_library( "//pkg/util/tracing", "@com_github_tikv_client_go_v2//tikvrpc", "@org_golang_x_exp//maps", + "@org_golang_x_sync//singleflight", "@org_uber_go_atomic//:atomic", ], ) diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index c05bf2e7b4425..5eb9beeb6aa52 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -50,6 +50,7 @@ import ( "github.com/tikv/client-go/v2/tikvrpc" atomic2 "go.uber.org/atomic" "golang.org/x/exp/maps" + "golang.org/x/sync/singleflight" ) var taskIDAlloc uint64 @@ -337,7 +338,7 @@ type StatementContext struct { // NeededItems stores the columns/indices whose stats are needed for planner. NeededItems []model.StatsLoadItem // ResultCh to receive stats loading results - ResultCh chan StatsLoadResult + ResultCh []<-chan singleflight.Result // LoadStartTime is to record the load start time to calculate latency LoadStartTime time.Time } diff --git a/pkg/statistics/fmsketch.go b/pkg/statistics/fmsketch.go index 147702545f488..01aa1ab99bd05 100644 --- a/pkg/statistics/fmsketch.go +++ b/pkg/statistics/fmsketch.go @@ -42,10 +42,22 @@ var fmSketchPool = sync.Pool{ }, } -// FMSketch is used to count the number of distinct elements in a set. +// MaxSketchSize is the maximum size of the hashset in the FM sketch. +// TODO: add this attribute to PB and persist it instead of using a fixed number(executor.maxSketchSize) +const MaxSketchSize = 10000 + +// FMSketch (Flajolet–Martin Sketch) is a probabilistic data structure used for estimating the number of distinct elements in a stream. +// It uses a hash function to map each element to a binary number and counts the number of trailing zeroes in each hashed value. +// The maximum number of trailing zeroes observed gives an estimate of the logarithm of the number of distinct elements. +// This approach allows the FM sketch to handle large streams of data in a memory-efficient way. +// +// See https://en.wikipedia.org/wiki/Flajolet%E2%80%93Martin_algorithm type FMSketch struct { + // A set to store unique hashed values. hashset *swiss.Map[uint64, bool] - mask uint64 + // A binary mask used to track the maximum number of trailing zeroes in the hashed values. + mask uint64 + // The maximum size of the hashset. If the size exceeds this value, the mask size will be doubled and some hashed values will be removed from the hashset. maxSize int } @@ -71,19 +83,30 @@ func (s *FMSketch) Copy() *FMSketch { return result } -// NDV returns the ndv of the sketch. +// NDV returns the estimated number of distinct values (NDV) in the sketch. func (s *FMSketch) NDV() int64 { if s == nil { return 0 } + // The size of the mask (incremented by one) is 2^r, where r is the maximum number of trailing zeroes observed in the hashed values. + // The count of unique hashed values is the number of unique elements in the hashset. + // This estimation method is based on the Flajolet-Martin algorithm for estimating the number of distinct elements in a stream. return int64(s.mask+1) * int64(s.hashset.Count()) } +// insertHashValue inserts a hashed value into the sketch. 
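// [Editor's worked example] After the mask has been doubled twice, mask = 0b11, so only
// hash values whose two lowest bits are zero (roughly 1/4 of all distinct values) remain
// in the hashset. With 5 remaining values, NDV() estimates (mask+1) * count = 4 * 5 = 20.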
func (s *FMSketch) insertHashValue(hashVal uint64) { + // If the hashed value is already in the sketch (determined by bitwise AND with the mask), return without inserting. + // This is because the number of trailing zeroes in the hashed value is less than or equal to the mask value. if (hashVal & s.mask) != 0 { return } + // Put the hashed value into the hashset. s.hashset.Put(hashVal, true) + // If the count of unique hashed values exceeds the maximum size, + // double the mask size and remove any hashed values from the hashset that are now within the mask. + // This is to ensure that the mask value is always a power of two minus one (i.e., a binary number of the form 111...), + // which allows us to quickly check the number of trailing zeroes in a hashed value by performing a bitwise AND operation with the mask. if s.hashset.Count() > s.maxSize { s.mask = s.mask*2 + 1 s.hashset.Iter(func(k uint64, _ bool) (stop bool) { @@ -204,7 +227,7 @@ func DecodeFMSketch(data []byte) (*FMSketch, error) { return nil, errors.Trace(err) } fm := FMSketchFromProto(p) - fm.maxSize = 10000 // TODO: add this attribute to PB and persist it instead of using a fixed number(executor.maxSketchSize) + fm.maxSize = MaxSketchSize return fm, nil } diff --git a/pkg/statistics/handle/syncload/BUILD.bazel b/pkg/statistics/handle/syncload/BUILD.bazel index ed6e310786a2a..3be7fe67caa52 100644 --- a/pkg/statistics/handle/syncload/BUILD.bazel +++ b/pkg/statistics/handle/syncload/BUILD.bazel @@ -17,9 +17,11 @@ go_library( "//pkg/statistics/handle/types", "//pkg/types", "//pkg/util", + "//pkg/util/intest", "//pkg/util/logutil", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", + "@org_golang_x_sync//singleflight", "@org_uber_go_zap//:zap", ], ) diff --git a/pkg/statistics/handle/syncload/stats_syncload.go b/pkg/statistics/handle/syncload/stats_syncload.go index 0ae6161a2cf8c..b0bd43166f3ce 100644 --- a/pkg/statistics/handle/syncload/stats_syncload.go +++ b/pkg/statistics/handle/syncload/stats_syncload.go @@ -32,8 +32,10 @@ import ( statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" "go.uber.org/zap" + "golang.org/x/sync/singleflight" ) // RetryCount is the max retry count for a sync load task. @@ -44,6 +46,8 @@ type statsSyncLoad struct { StatsLoad statstypes.StatsLoad } +var globalStatsSyncLoadSingleFlight singleflight.Group + // NewStatsSyncLoad creates a new StatsSyncLoad. 
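// [Editor's note] A standalone sketch of the deduplication that
// globalStatsSyncLoadSingleFlight enables; only the golang.org/x/sync/singleflight API
// is assumed, and loadFn is hypothetical:
//
//	var g singleflight.Group
//	loadFn := func() (any, error) { /* load the stats item once */ return nil, nil }
//	ch1 := g.DoChan("table:1#col:2", loadFn)
//	ch2 := g.DoChan("table:1#col:2", loadFn) // same key: shares the in-flight call
//	r := <-ch2
//	_, _, _ = r.Val, r.Err, r.Shared // Shared is true when the result served several callers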
func NewStatsSyncLoad(statsHandle statstypes.StatsHandle) statstypes.StatsSyncLoad { s := &statsSyncLoad{statsHandle: statsHandle} @@ -78,25 +82,27 @@ func (s *statsSyncLoad) SendLoadRequests(sc *stmtctx.StatementContext, neededHis } sc.StatsLoad.Timeout = timeout sc.StatsLoad.NeededItems = remainedItems - sc.StatsLoad.ResultCh = make(chan stmtctx.StatsLoadResult, len(remainedItems)) - tasks := make([]*statstypes.NeededItemTask, 0) + sc.StatsLoad.ResultCh = make([]<-chan singleflight.Result, 0, len(remainedItems)) for _, item := range remainedItems { - task := &statstypes.NeededItemTask{ - Item: item, - ToTimeout: time.Now().Local().Add(timeout), - ResultCh: sc.StatsLoad.ResultCh, - } - tasks = append(tasks, task) - } - timer := time.NewTimer(timeout) - defer timer.Stop() - for _, task := range tasks { - select { - case s.StatsLoad.NeededItemsCh <- task: - continue - case <-timer.C: - return errors.New("sync load stats channel is full and timeout sending task to channel") - } + localItem := item + resultCh := globalStatsSyncLoadSingleFlight.DoChan(localItem.Key(), func() (any, error) { + timer := time.NewTimer(timeout) + defer timer.Stop() + task := &statstypes.NeededItemTask{ + Item: localItem, + ToTimeout: time.Now().Local().Add(timeout), + ResultCh: make(chan stmtctx.StatsLoadResult, 1), + } + select { + case s.StatsLoad.NeededItemsCh <- task: + result, ok := <-task.ResultCh + intest.Assert(ok, "task.ResultCh cannot be closed") + return result, nil + case <-timer.C: + return nil, errors.New("sync load stats channel is full and timeout sending task to channel") + } + }) + sc.StatsLoad.ResultCh = append(sc.StatsLoad.ResultCh, resultCh) } sc.StatsLoad.LoadStartTime = time.Now() return nil @@ -122,25 +128,34 @@ func (*statsSyncLoad) SyncWaitStatsLoad(sc *stmtctx.StatementContext) error { metrics.SyncLoadCounter.Inc() timer := time.NewTimer(sc.StatsLoad.Timeout) defer timer.Stop() - for { + for _, resultCh := range sc.StatsLoad.ResultCh { select { - case result, ok := <-sc.StatsLoad.ResultCh: + case result, ok := <-resultCh: if !ok { return errors.New("sync load stats channel closed unexpectedly") } - if result.HasError() { - errorMsgs = append(errorMsgs, result.ErrorMsg()) - } - delete(resultCheckMap, result.Item) - if len(resultCheckMap) == 0 { - metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds())) - return nil + // this error is from statsSyncLoad.SendLoadRequests which start to task and send task into worker, + // not the stats loading error + if result.Err != nil { + errorMsgs = append(errorMsgs, result.Err.Error()) + } else { + val := result.Val.(stmtctx.StatsLoadResult) + // this error is from the stats loading error + if val.HasError() { + errorMsgs = append(errorMsgs, val.ErrorMsg()) + } + delete(resultCheckMap, val.Item) } case <-timer.C: metrics.SyncLoadTimeoutCounter.Inc() return errors.New("sync load stats timeout") } } + if len(resultCheckMap) == 0 { + metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds())) + return nil + } + return nil } // removeHistLoadedColumns removed having-hist columns based on neededColumns and statsCache. 
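// [Editor's note] With this change each entry of sc.StatsLoad.ResultCh is a singleflight
// result channel for one needed item, so SyncWaitStatsLoad above receives exactly once
// per channel: result.Err reports that the request could not even be queued, while a
// received stmtctx.StatsLoadResult with HasError() carries the actual stats-loading failure.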
@@ -230,33 +245,17 @@ func (s *statsSyncLoad) HandleOneTask(sctx sessionctx.Context, lastTask *statsty task = lastTask } result := stmtctx.StatsLoadResult{Item: task.Item.TableItemID} - resultChan := s.StatsLoad.Singleflight.DoChan(task.Item.Key(), func() (any, error) { - err := s.handleOneItemTask(task) - return nil, err - }) - timeout := time.Until(task.ToTimeout) - select { - case sr := <-resultChan: - // sr.Val is always nil. - if sr.Err == nil { - task.ResultCh <- result - return nil, nil - } - if !isVaildForRetry(task) { - result.Error = sr.Err - task.ResultCh <- result - return nil, nil - } - return task, sr.Err - case <-time.After(timeout): - if !isVaildForRetry(task) { - result.Error = errors.New("stats loading timeout") - task.ResultCh <- result - return nil, nil - } - task.ToTimeout.Add(time.Duration(sctx.GetSessionVars().StatsLoadSyncWait.Load()) * time.Microsecond) - return task, nil + err = s.handleOneItemTask(task) + if err == nil { + task.ResultCh <- result + return nil, nil + } + if !isVaildForRetry(task) { + result.Error = err + task.ResultCh <- result + return nil, nil } + return task, err } func isVaildForRetry(task *statstypes.NeededItemTask) bool { diff --git a/pkg/statistics/handle/syncload/stats_syncload_test.go b/pkg/statistics/handle/syncload/stats_syncload_test.go index 4b38387430c49..8a8929d9d93e5 100644 --- a/pkg/statistics/handle/syncload/stats_syncload_test.go +++ b/pkg/statistics/handle/syncload/stats_syncload_test.go @@ -208,13 +208,23 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) { task1, err1 := h.HandleOneTask(testKit.Session().(sessionctx.Context), nil, exitCh) require.Error(t, err1) require.NotNil(t, task1) + for _, resultCh := range stmtCtx1.StatsLoad.ResultCh { + select { + case <-resultCh: + t.Logf("stmtCtx1.ResultCh should not get anything") + t.FailNow() + default: + } + } + for _, resultCh := range stmtCtx2.StatsLoad.ResultCh { + select { + case <-resultCh: + t.Logf("stmtCtx1.ResultCh should not get anything") + t.FailNow() + default: + } + } select { - case <-stmtCtx1.StatsLoad.ResultCh: - t.Logf("stmtCtx1.ResultCh should not get anything") - t.FailNow() - case <-stmtCtx2.StatsLoad.ResultCh: - t.Logf("stmtCtx2.ResultCh should not get anything") - t.FailNow() case <-task1.ResultCh: t.Logf("task1.ResultCh should not get anything") t.FailNow() @@ -225,17 +235,18 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) { task3, err3 := h.HandleOneTask(testKit.Session().(sessionctx.Context), task1, exitCh) require.NoError(t, err3) require.Nil(t, task3) - - task, err3 := h.HandleOneTask(testKit.Session().(sessionctx.Context), nil, exitCh) - require.NoError(t, err3) - require.Nil(t, task) - - rs1, ok1 := <-stmtCtx1.StatsLoad.ResultCh - require.True(t, ok1) - require.Equal(t, neededColumns[0].TableItemID, rs1.Item) - rs2, ok2 := <-stmtCtx2.StatsLoad.ResultCh - require.True(t, ok2) - require.Equal(t, neededColumns[0].TableItemID, rs2.Item) + for _, resultCh := range stmtCtx1.StatsLoad.ResultCh { + rs1, ok1 := <-resultCh + require.True(t, rs1.Shared) + require.True(t, ok1) + require.Equal(t, neededColumns[0].TableItemID, rs1.Val.(stmtctx.StatsLoadResult).Item) + } + for _, resultCh := range stmtCtx2.StatsLoad.ResultCh { + rs1, ok1 := <-resultCh + require.True(t, rs1.Shared) + require.True(t, ok1) + require.Equal(t, neededColumns[0].TableItemID, rs1.Val.(stmtctx.StatsLoadResult).Item) + } stat = h.GetTableStats(tableInfo) hg := stat.Columns[tableInfo.Columns[2].ID].Histogram @@ -312,11 +323,11 @@ func TestRetry(t *testing.T) { result, 
     require.NoError(t, err1)
     require.Nil(t, result)
-    select {
-    case <-task1.ResultCh:
-    default:
-        t.Logf("task1.ResultCh should get nothing")
-        t.FailNow()
+    for _, resultCh := range stmtCtx1.StatsLoad.ResultCh {
+        rs1, ok1 := <-resultCh
+        require.True(t, rs1.Shared)
+        require.True(t, ok1)
+        require.Error(t, rs1.Val.(stmtctx.StatsLoadResult).Error)
     }
     task1.Retry = 0
     for i := 0; i < syncload.RetryCount*5; i++ {
diff --git a/pkg/statistics/handle/types/BUILD.bazel b/pkg/statistics/handle/types/BUILD.bazel
index 328d1a75b1159..df7a6ea2acfa1 100644
--- a/pkg/statistics/handle/types/BUILD.bazel
+++ b/pkg/statistics/handle/types/BUILD.bazel
@@ -17,6 +17,5 @@ go_library(
         "//pkg/types",
         "//pkg/util",
         "//pkg/util/sqlexec",
-        "@org_golang_x_sync//singleflight",
     ],
 )
diff --git a/pkg/statistics/handle/types/interfaces.go b/pkg/statistics/handle/types/interfaces.go
index 5c1b41d7fbd65..8726cd7d64a7c 100644
--- a/pkg/statistics/handle/types/interfaces.go
+++ b/pkg/statistics/handle/types/interfaces.go
@@ -30,7 +30,6 @@ import (
     "github.com/pingcap/tidb/pkg/types"
     "github.com/pingcap/tidb/pkg/util"
     "github.com/pingcap/tidb/pkg/util/sqlexec"
-    "golang.org/x/sync/singleflight"
 )
 
 // StatsGC is used to GC unnecessary stats.
@@ -398,7 +397,6 @@ type NeededItemTask struct {
 type StatsLoad struct {
     NeededItemsCh  chan *NeededItemTask
     TimeoutItemsCh chan *NeededItemTask
-    Singleflight   singleflight.Group
     sync.Mutex
 }
diff --git a/pkg/structure/list.go b/pkg/structure/list.go
index 5674229665a80..3211668f01d02 100644
--- a/pkg/structure/list.go
+++ b/pkg/structure/list.go
@@ -22,6 +22,7 @@ import (
     "github.com/pingcap/tidb/pkg/kv"
 )
 
+// valid index range: [LIndex, RIndex)
 type listMeta struct {
     LIndex int64
     RIndex int64
diff --git a/pkg/util/mock/context.go b/pkg/util/mock/context.go
index f9c0aed4c8051..cef4cd19e0b0c 100644
--- a/pkg/util/mock/context.go
+++ b/pkg/util/mock/context.go
@@ -229,6 +229,11 @@ func (c *Context) GetPlanCtx() planctx.PlanContext {
     return c
 }
 
+// GetNullRejectCheckExprCtx returns the expression context with the null-reject check enabled.
+func (c *Context) GetNullRejectCheckExprCtx() exprctx.ExprContext {
+    return exprctx.WithNullRejectCheck(c)
+}
+
 // GetExprCtx returns the expression context of the session.
 func (c *Context) GetExprCtx() exprctx.ExprContext {
     return c
diff --git a/pkg/util/ranger/detacher.go b/pkg/util/ranger/detacher.go
index 4d6f26bfb5cfa..a482767803f04 100644
--- a/pkg/util/ranger/detacher.go
+++ b/pkg/util/ranger/detacher.go
@@ -487,7 +487,9 @@ func (d *rangeDetacher) detachCNFCondAndBuildRangeForIndex(conditions []expressi
 // excludeToIncludeForIntPoint converts `(i` to `[i+1` and `i)` to `i-1]` if `i` is integer.
 // For example, if p is `(3`, i.e., point { value: int(3), excl: true, start: true }, it is equal to `[4`, i.e., point { value: int(4), excl: false, start: true }.
 // Similarly, if p is `8)`, i.e., point { value: int(8), excl: true, start: false}, it is equal to `7]`, i.e., point { value: int(7), excl: false, start: false }.
-// If return value is nil, it means p is unsatisfiable. For example, `(MaxInt64` is unsatisfiable.
+// If the return value is nil, it means p is unsatisfiable. For example, `(MaxUint64` is unsatisfiable.
+// The boundary value is promoted to the wider type: for example, `(MaxInt64` of type KindInt64 becomes `[MaxInt64+1` of type KindUint64,
+// and, conversely, `0)` of type KindUint64 becomes `-1]` of type KindInt64.
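+// For example, for an unsigned column c, `c > 9223372036854775807` produces the start point
+// `(9223372036854775807` of type KindInt64; instead of being dropped as unsatisfiable, it now becomes
+// the inclusive start `[9223372036854775808` of type KindUint64, so a row with c = 9223372036854775810
+// (see TestIssue50051 below) keeps matching.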
 func excludeToIncludeForIntPoint(p *point) *point {
     if !p.excl {
         return p
@@ -496,9 +498,10 @@ func excludeToIncludeForIntPoint(p *point) *point {
         val := p.value.GetInt64()
         if p.start {
             if val == math.MaxInt64 {
-                return nil
+                p.value.SetUint64(uint64(val + 1))
+            } else {
+                p.value.SetInt64(val + 1)
             }
-            p.value.SetInt64(val + 1)
             p.excl = false
         } else {
             if val == math.MinInt64 {
@@ -517,9 +520,10 @@ func excludeToIncludeForIntPoint(p *point) *point {
             p.excl = false
         } else {
             if val == 0 {
-                return nil
+                p.value.SetInt64(int64(val - 1))
+            } else {
+                p.value.SetUint64(val - 1)
             }
-            p.value.SetUint64(val - 1)
             p.excl = false
         }
     }
diff --git a/pkg/util/ranger/ranger_test.go b/pkg/util/ranger/ranger_test.go
index fb6dbc2386c00..0a260cc3bc05a 100644
--- a/pkg/util/ranger/ranger_test.go
+++ b/pkg/util/ranger/ranger_test.go
@@ -2340,3 +2340,22 @@ func TestIssue40997(t *testing.T) {
         "└─TableRowIDScan_6(Probe) 0.67 cop[tikv] table:t71706696 keep order:false, stats:pseudo",
     ))
 }
+
+func TestIssue50051(t *testing.T) {
+    store := testkit.CreateMockStore(t)
+    tk := testkit.NewTestKit(t, store)
+    tk.MustExec("use test")
+
+    tk.MustExec("drop table if exists tt")
+    tk.MustExec("CREATE TABLE tt (c bigint UNSIGNED not null, d int not null, PRIMARY KEY (c,d));")
+    tk.MustExec("insert into tt values (9223372036854775810, 3);")
+    tk.MustQuery("SELECT c FROM tt WHERE c>9223372036854775807 AND c>1;").Check(testkit.Rows("9223372036854775810"))
+
+    tk.MustExec("drop table if exists t5")
+    tk.MustExec("drop table if exists t6")
+    tk.MustExec("CREATE TABLE `t5` (`d` int not null, `c` int not null, PRIMARY KEY (`d`, `c`));")
+    tk.MustExec("CREATE TABLE `t6` (`d` bigint UNSIGNED not null);")
+    tk.MustExec("insert into t5 values (-3, 6);")
+    tk.MustExec("insert into t6 values (0), (1), (2), (3);")
+    tk.MustQuery("select d from t5 where d < (select min(d) from t6) and d < 3;").Check(testkit.Rows("-3"))
+}
diff --git a/tests/integrationtest/r/planner/core/indexjoin.result b/tests/integrationtest/r/planner/core/indexjoin.result
index 63353d01f9148..2f1a3c00943aa 100644
--- a/tests/integrationtest/r/planner/core/indexjoin.result
+++ b/tests/integrationtest/r/planner/core/indexjoin.result
@@ -163,22 +163,20 @@ create table t3 (a int, b float, index idx (b));
 create table t4 (a int, b double);
 insert into t3 values (1, 1.0), (1, 2.0), (2, 3.0);
 insert into t4 values (1, 1.11111111);
-explain format='brief' select /*+ INL_JOIN(tmp) */ * from (select b from t group by b) tmp, t1 where tmp.b=t1.b;
+explain format='brief' select /*+ INL_JOIN(tmp) */ * from (select b from t3 group by b) tmp, t4 where tmp.b=t4.b;
 id estRows task access object operator info
-Projection 9990.00 root planner__core__indexjoin.t.b, planner__core__indexjoin.t1.a, planner__core__indexjoin.t1.b
-└─HashJoin 9990.00 root inner join, equal:[eq(planner__core__indexjoin.t1.b, planner__core__indexjoin.t.b)]
-  ├─HashAgg(Build) 7992.00 root group by:planner__core__indexjoin.t.b, funcs:firstrow(planner__core__indexjoin.t.b)->planner__core__indexjoin.t.b
-  │ └─IndexReader 7992.00 root index:HashAgg
-  │   └─HashAgg 7992.00 cop[tikv] group by:planner__core__indexjoin.t.b,
-  │     └─Selection 9990.00 cop[tikv] not(isnull(planner__core__indexjoin.t.b))
-  │       └─IndexFullScan 10000.00 cop[tikv] table:t, index:idx(a, b) keep order:false, stats:pseudo
-  └─IndexReader(Probe) 9990.00 root index:Selection
-    └─Selection 9990.00 cop[tikv] not(isnull(planner__core__indexjoin.t1.b))
-      └─IndexFullScan 10000.00 cop[tikv] table:t1, index:idx(a, b) keep order:false, stats:pseudo
-select /*+ INL_JOIN(tmp) */ tmp.b, t1.b from (select b from t group by b) tmp, t1 where tmp.b=t1.b order by tmp.b, t1.b;
+Projection 9990.00 root planner__core__indexjoin.t3.b, planner__core__indexjoin.t4.a, planner__core__indexjoin.t4.b
+└─IndexJoin 9990.00 root inner join, inner:HashAgg, outer key:planner__core__indexjoin.t4.b, inner key:planner__core__indexjoin.t3.b, equal cond:eq(planner__core__indexjoin.t4.b, planner__core__indexjoin.t3.b)
+  ├─TableReader(Build) 9990.00 root data:Selection
+  │ └─Selection 9990.00 cop[tikv] not(isnull(planner__core__indexjoin.t4.b))
+  │   └─TableFullScan 10000.00 cop[tikv] table:t4 keep order:false, stats:pseudo
+  └─HashAgg(Probe) 79840080.00 root group by:planner__core__indexjoin.t3.b, funcs:firstrow(planner__core__indexjoin.t3.b)->planner__core__indexjoin.t3.b
+    └─IndexReader 79840080.00 root index:HashAgg
+      └─HashAgg 79840080.00 cop[tikv] group by:planner__core__indexjoin.t3.b,
+        └─Selection 9990.00 cop[tikv] not(isnull(planner__core__indexjoin.t3.b))
+          └─IndexRangeScan 10000.00 cop[tikv] table:t3, index:idx(b) range: decided by [eq(planner__core__indexjoin.t3.b, planner__core__indexjoin.t4.b)], keep order:false, stats:pseudo
+select /*+ INL_JOIN(tmp) */ tmp.b, t4.b from (select b from t3 group by b) tmp, t4 where tmp.b=t4.b order by tmp.b, t4.b;
 b b
-1 1
-2 2
 explain format='brief' select /*+ INL_JOIN(tmp) */ * from (select a, b from t where a>=1 group by a, b) tmp, t1 where tmp.a=t1.a;
 id estRows task access object operator info
 IndexJoin 1.25 root inner join, inner:HashAgg, outer key:planner__core__indexjoin.t1.a, inner key:planner__core__indexjoin.t.a, equal cond:eq(planner__core__indexjoin.t1.a, planner__core__indexjoin.t.a)
diff --git a/tests/integrationtest/t/planner/core/indexjoin.test b/tests/integrationtest/t/planner/core/indexjoin.test
index cdfd0b6484d97..2f1ae3b62c463 100644
--- a/tests/integrationtest/t/planner/core/indexjoin.test
+++ b/tests/integrationtest/t/planner/core/indexjoin.test
@@ -42,8 +42,8 @@ create table t3 (a int, b float, index idx (b));
 create table t4 (a int, b double);
 insert into t3 values (1, 1.0), (1, 2.0), (2, 3.0);
 insert into t4 values (1, 1.11111111);
-explain format='brief' select /*+ INL_JOIN(tmp) */ * from (select b from t group by b) tmp, t1 where tmp.b=t1.b;
-select /*+ INL_JOIN(tmp) */ tmp.b, t1.b from (select b from t group by b) tmp, t1 where tmp.b=t1.b order by tmp.b, t1.b;
+explain format='brief' select /*+ INL_JOIN(tmp) */ * from (select b from t3 group by b) tmp, t4 where tmp.b=t4.b;
+select /*+ INL_JOIN(tmp) */ tmp.b, t4.b from (select b from t3 group by b) tmp, t4 where tmp.b=t4.b order by tmp.b, t4.b;
 
 # Test the selection, projection inside of "zippedChildren"
 explain format='brief' select /*+ INL_JOIN(tmp) */ * from (select a, b from t where a>=1 group by a, b) tmp, t1 where tmp.a=t1.a;
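Reviewer note, not part of the patch: the two hunks above repoint the integration test at the freshly created t3/t4 tables (it previously reused t/t1 from earlier cases, so the hint never exercised the aggregated derived table), and the expected plan accordingly changes from a HashJoin to an IndexJoin over the HashAgg inner side. A minimal sketch of the same check as a unit test, assuming the testkit helpers imported by ranger_test.go above:

func TestINLJoinOnAggregatedDerivedTable(t *testing.T) {
    store := testkit.CreateMockStore(t)
    tk := testkit.NewTestKit(t, store)
    tk.MustExec("use test")
    tk.MustExec("create table t3 (a int, b float, index idx (b))")
    tk.MustExec("create table t4 (a int, b double)")
    tk.MustExec("insert into t3 values (1, 1.0), (1, 2.0), (2, 3.0)")
    tk.MustExec("insert into t4 values (1, 1.11111111)")
    // The float values in t3.b never equal the double 1.11111111, so the join
    // is empty, matching the updated .result expectation above.
    tk.MustQuery("select /*+ INL_JOIN(tmp) */ tmp.b, t4.b from (select b from t3 group by b) tmp, t4 where tmp.b=t4.b order by tmp.b, t4.b").Check(testkit.Rows())
}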