diff --git a/br/pkg/backup/client.go b/br/pkg/backup/client.go index 5414b35518de8..755a8217b9c00 100644 --- a/br/pkg/backup/client.go +++ b/br/pkg/backup/client.go @@ -674,7 +674,7 @@ func BuildBackupSchemas( // Treat cached table as normal table. tableInfo.TableCacheStatusType = model.TableCacheStatusDisable - if tableInfo.PKIsHandle && tableInfo.ContainsAutoRandomBits() { + if tableInfo.ContainsAutoRandomBits() { // this table has auto_random id, we need backup and rebase in restoration var globalAutoRandID int64 globalAutoRandID, err = autoIDAccess.RandomID().Get() diff --git a/br/pkg/lightning/backend/backend.go b/br/pkg/lightning/backend/backend.go index b31f88324eba8..9c99212930845 100644 --- a/br/pkg/lightning/backend/backend.go +++ b/br/pkg/lightning/backend/backend.go @@ -106,6 +106,10 @@ type CheckCtx struct { // TargetInfoGetter defines the interfaces to get target information. type TargetInfoGetter interface { + // FetchRemoteDBModels obtains the models of all databases. Currently, only + // the database name is filled. + FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) + // FetchRemoteTableModels obtains the models of all tables given the schema // name. The returned table info does not need to be precise if the encoder, // is not requiring them, but must at least fill in the following fields for diff --git a/br/pkg/lightning/backend/kv/BUILD.bazel b/br/pkg/lightning/backend/kv/BUILD.bazel index 087bb5849e52d..f880519eb71f4 100644 --- a/br/pkg/lightning/backend/kv/BUILD.bazel +++ b/br/pkg/lightning/backend/kv/BUILD.bazel @@ -34,6 +34,7 @@ go_library( "//tablecodec", "//types", "//util/chunk", + "//util/codec", "//util/mathutil", "//util/topsql/stmtstats", "@com_github_docker_go_units//:go-units", diff --git a/br/pkg/lightning/backend/kv/base.go b/br/pkg/lightning/backend/kv/base.go index 6e3f165d92fc6..3f90aaa66b2d5 100644 --- a/br/pkg/lightning/backend/kv/base.go +++ b/br/pkg/lightning/backend/kv/base.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/codec" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) @@ -204,7 +205,7 @@ func (e *BaseKVEncoder) Record2KV(record, originalRow []types.Datum, rowID int64 kvPairs := e.SessionCtx.TakeKvPairs() for i := 0; i < len(kvPairs.Pairs); i++ { var encoded [9]byte // The max length of encoded int64 is 9. - kvPairs.Pairs[i].RowID = common.EncodeIntRowIDToBuf(encoded[:0], rowID) + kvPairs.Pairs[i].RowID = codec.EncodeComparableVarint(encoded[:0], rowID) } e.recordCache = record[:0] return kvPairs, nil diff --git a/br/pkg/lightning/backend/local/iterator.go b/br/pkg/lightning/backend/local/iterator.go index 9595f8f68ea51..8877fdb958c1e 100644 --- a/br/pkg/lightning/backend/local/iterator.go +++ b/br/pkg/lightning/backend/local/iterator.go @@ -27,9 +27,6 @@ import ( // Iter abstract iterator method for Ingester. type Iter interface { - // Seek seek to specify position. - // if key not found, seeks next key position in iter. - Seek(key []byte) bool // Error return current error on this iter. Error() error // First moves this iter to the first key. 
@@ -88,15 +85,6 @@ type DupDetectOpt struct { ReportErrOnDup bool } -func (d *dupDetectIter) Seek(key []byte) bool { - rawKey := d.keyAdapter.Encode(nil, key, ZeroRowID) - if d.err != nil || !d.iter.SeekGE(rawKey) { - return false - } - d.fill() - return d.err == nil -} - func (d *dupDetectIter) First() bool { if d.err != nil || !d.iter.First() { return false @@ -155,7 +143,7 @@ func (d *dupDetectIter) Next() bool { } if d.option.ReportErrOnDup { dupKey := make([]byte, len(d.curKey)) - dupVal := make([]byte, len(d.iter.Value())) + dupVal := make([]byte, len(d.curVal)) copy(dupKey, d.curKey) copy(dupVal, d.curVal) d.err = common.ErrFoundDuplicateKeys.FastGenByArgs(dupKey, dupVal) @@ -225,15 +213,6 @@ type dupDBIter struct { err error } -func (d *dupDBIter) Seek(key []byte) bool { - rawKey := d.keyAdapter.Encode(nil, key, ZeroRowID) - if d.err != nil || !d.iter.SeekGE(rawKey) { - return false - } - d.curKey, d.err = d.keyAdapter.Decode(d.curKey[:0], d.iter.Key()) - return d.err == nil -} - func (d *dupDBIter) Error() error { if d.err != nil { return d.err diff --git a/br/pkg/lightning/backend/local/iterator_test.go b/br/pkg/lightning/backend/local/iterator_test.go index 0a00cb5864cc0..e376550b5431f 100644 --- a/br/pkg/lightning/backend/local/iterator_test.go +++ b/br/pkg/lightning/backend/local/iterator_test.go @@ -178,55 +178,6 @@ func TestDupDetectIterator(t *testing.T) { } } -func TestDupDetectIterSeek(t *testing.T) { - pairs := []common.KvPair{ - { - Key: []byte{1, 2, 3, 0}, - Val: randBytes(128), - RowID: common.EncodeIntRowID(1), - }, - { - Key: []byte{1, 2, 3, 1}, - Val: randBytes(128), - RowID: common.EncodeIntRowID(2), - }, - { - Key: []byte{1, 2, 3, 1}, - Val: randBytes(128), - RowID: common.EncodeIntRowID(3), - }, - { - Key: []byte{1, 2, 3, 2}, - Val: randBytes(128), - RowID: common.EncodeIntRowID(4), - }, - } - - storeDir := t.TempDir() - db, err := pebble.Open(filepath.Join(storeDir, "kv"), &pebble.Options{}) - require.NoError(t, err) - - keyAdapter := dupDetectKeyAdapter{} - wb := db.NewBatch() - for _, p := range pairs { - key := keyAdapter.Encode(nil, p.Key, p.RowID) - require.NoError(t, wb.Set(key, p.Val, nil)) - } - require.NoError(t, wb.Commit(pebble.Sync)) - - dupDB, err := pebble.Open(filepath.Join(storeDir, "duplicates"), &pebble.Options{}) - require.NoError(t, err) - iter := newDupDetectIter(db, keyAdapter, &pebble.IterOptions{}, dupDB, log.L(), DupDetectOpt{}) - - require.True(t, iter.Seek([]byte{1, 2, 3, 1})) - require.Equal(t, pairs[1].Val, iter.Value()) - require.True(t, iter.Next()) - require.Equal(t, pairs[3].Val, iter.Value()) - require.NoError(t, iter.Close()) - require.NoError(t, db.Close()) - require.NoError(t, dupDB.Close()) -} - func TestKeyAdapterEncoding(t *testing.T) { keyAdapter := dupDetectKeyAdapter{} srcKey := []byte{1, 2, 3} diff --git a/br/pkg/lightning/backend/local/key_adapter.go b/br/pkg/lightning/backend/local/key_adapter.go index 5d9d119b2c3ec..a59428dfef49d 100644 --- a/br/pkg/lightning/backend/local/key_adapter.go +++ b/br/pkg/lightning/backend/local/key_adapter.go @@ -15,10 +15,7 @@ package local import ( - "math" - "github.com/pingcap/errors" - "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/util/codec" ) @@ -102,8 +99,7 @@ func (dupDetectKeyAdapter) EncodedLen(key []byte, rowID []byte) int { var _ KeyAdapter = dupDetectKeyAdapter{} -// static vars for rowID var ( - MinRowID = common.EncodeIntRowID(math.MinInt64) - ZeroRowID = common.EncodeIntRowID(0) + // MinRowID is the minimum rowID value after 
dupDetectKeyAdapter.Encode().
+	MinRowID = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0}
 )
diff --git a/br/pkg/lightning/backend/local/key_adapter_test.go b/br/pkg/lightning/backend/local/key_adapter_test.go
index d80efa6de2af4..cc42f4f283f25 100644
--- a/br/pkg/lightning/backend/local/key_adapter_test.go
+++ b/br/pkg/lightning/backend/local/key_adapter_test.go
@@ -20,9 +20,14 @@ import (
 	"math"
 	"sort"
 	"testing"
+	"time"
 	"unsafe"
 
 	"github.com/pingcap/tidb/br/pkg/lightning/common"
+	"github.com/pingcap/tidb/kv"
+	"github.com/pingcap/tidb/sessionctx/stmtctx"
+	"github.com/pingcap/tidb/types"
+	"github.com/pingcap/tidb/util/codec"
 	"github.com/stretchr/testify/require"
 )
 
@@ -35,8 +40,9 @@ func randBytes(n int) []byte {
 func TestNoopKeyAdapter(t *testing.T) {
 	keyAdapter := noopKeyAdapter{}
 	key := randBytes(32)
-	require.Len(t, key, keyAdapter.EncodedLen(key, ZeroRowID))
-	encodedKey := keyAdapter.Encode(nil, key, ZeroRowID)
+	rowID := randBytes(8)
+	require.Len(t, key, keyAdapter.EncodedLen(key, rowID))
+	encodedKey := keyAdapter.Encode(nil, key, rowID)
 	require.Equal(t, key, encodedKey)
 
 	decodedKey, err := keyAdapter.Decode(nil, encodedKey)
@@ -160,3 +166,50 @@ func TestDecodeKeyDstIsInsufficient(t *testing.T) {
 		require.Equal(t, key, buf2[4:])
 	}
 }
+
+func TestMinRowID(t *testing.T) {
+	keyAdapter := dupDetectKeyAdapter{}
+	key := []byte("key")
+	val := []byte("val")
+	shouldBeMin := keyAdapter.Encode(key, val, MinRowID)
+
+	rowIDs := make([][]byte, 0, 20)
+
+	// DDL
+
+	rowIDs = append(rowIDs, kv.IntHandle(math.MinInt64).Encoded())
+	rowIDs = append(rowIDs, kv.IntHandle(-1).Encoded())
+	rowIDs = append(rowIDs, kv.IntHandle(0).Encoded())
+	rowIDs = append(rowIDs, kv.IntHandle(math.MaxInt64).Encoded())
+	handleData := []types.Datum{
+		types.NewIntDatum(math.MinInt64),
+		types.NewIntDatum(-1),
+		types.NewIntDatum(0),
+		types.NewIntDatum(math.MaxInt64),
+		types.NewBytesDatum(make([]byte, 1)),
+		types.NewBytesDatum(make([]byte, 7)),
+		types.NewBytesDatum(make([]byte, 8)),
+		types.NewBytesDatum(make([]byte, 9)),
+		types.NewBytesDatum(make([]byte, 100)),
+	}
+	for _, d := range handleData {
+		sc := &stmtctx.StatementContext{TimeZone: time.Local}
+		encodedKey, err := codec.EncodeKey(sc, nil, d)
+		require.NoError(t, err)
+		ch, err := kv.NewCommonHandle(encodedKey)
+		require.NoError(t, err)
+		rowIDs = append(rowIDs, ch.Encoded())
+	}
+
+	// lightning, IMPORT INTO, ...
+
+	numRowIDs := []int64{math.MinInt64, -1, 0, math.MaxInt64}
+	for _, id := range numRowIDs {
+		rowIDs = append(rowIDs, codec.EncodeComparableVarint(nil, id))
+	}
+
+	for _, id := range rowIDs {
+		bs := keyAdapter.Encode(key, val, id)
+		require.True(t, bytes.Compare(bs, shouldBeMin) >= 0)
+	}
+}
diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go
index 94090385b5cd3..5e2c6ddd72831 100644
--- a/br/pkg/lightning/backend/local/local.go
+++ b/br/pkg/lightning/backend/local/local.go
@@ -279,6 +279,11 @@ func NewTargetInfoGetter(tls *common.TLS, db *sql.DB, pdCli pd.Client) backend.T
 	}
 }
 
+// FetchRemoteDBModels implements the `backend.TargetInfoGetter` interface.
+func (g *targetInfoGetter) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) {
+	return tikv.FetchRemoteDBModelsFromTLS(ctx, g.tls)
+}
+
 // FetchRemoteTableModels obtains the models of all tables given the schema name.
 // It implements the `TargetInfoGetter` interface.
func (g *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { diff --git a/br/pkg/lightning/backend/tidb/tidb.go b/br/pkg/lightning/backend/tidb/tidb.go index 774191d48431d..a67605219e724 100644 --- a/br/pkg/lightning/backend/tidb/tidb.go +++ b/br/pkg/lightning/backend/tidb/tidb.go @@ -136,6 +136,38 @@ func NewTargetInfoGetter(db *sql.DB) backend.TargetInfoGetter { } } +// FetchRemoteDBModels implements the `backend.TargetInfoGetter` interface. +func (b *targetInfoGetter) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) { + results := []*model.DBInfo{} + logger := log.FromContext(ctx) + s := common.SQLWithRetry{ + DB: b.db, + Logger: logger, + } + err := s.Transact(ctx, "fetch db models", func(_ context.Context, tx *sql.Tx) error { + results = results[:0] + + rows, e := tx.Query("SHOW DATABASES") + if e != nil { + return e + } + defer rows.Close() + + for rows.Next() { + var dbName string + if e := rows.Scan(&dbName); e != nil { + return e + } + dbInfo := &model.DBInfo{ + Name: model.NewCIStr(dbName), + } + results = append(results, dbInfo) + } + return rows.Err() + }) + return results, err +} + // FetchRemoteTableModels obtains the models of all tables given the schema name. // It implements the `backend.TargetInfoGetter` interface. // TODO: refactor diff --git a/br/pkg/lightning/checkpoints/checkpoints.go b/br/pkg/lightning/checkpoints/checkpoints.go index 07e1c319c6813..39bceba8677fd 100644 --- a/br/pkg/lightning/checkpoints/checkpoints.go +++ b/br/pkg/lightning/checkpoints/checkpoints.go @@ -69,7 +69,7 @@ const WholeTableEngineID = math.MaxInt32 // remember to increase the version number in case of incompatible change. const ( CheckpointTableNameTask = "task_v2" - CheckpointTableNameTable = "table_v8" + CheckpointTableNameTable = "table_v9" CheckpointTableNameEngine = "engine_v5" CheckpointTableNameChunk = "chunk_v5" ) @@ -106,7 +106,7 @@ const ( status tinyint unsigned DEFAULT 30, alloc_base bigint NOT NULL DEFAULT 0, table_id bigint NOT NULL DEFAULT 0, - table_info text NOT NULL, + table_info longtext NOT NULL, create_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, update_time timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, kv_bytes bigint unsigned NOT NULL DEFAULT 0, diff --git a/br/pkg/lightning/common/BUILD.bazel b/br/pkg/lightning/common/BUILD.bazel index 519e81ed03175..92736a56e4117 100644 --- a/br/pkg/lightning/common/BUILD.bazel +++ b/br/pkg/lightning/common/BUILD.bazel @@ -126,6 +126,7 @@ go_test( "@com_github_stretchr_testify//require", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", + "@org_golang_x_time//rate", "@org_uber_go_goleak//:goleak", "@org_uber_go_multierr//:multierr", ], diff --git a/br/pkg/lightning/common/retry.go b/br/pkg/lightning/common/retry.go index f5270cdbb2fd4..b701afefe4a17 100644 --- a/br/pkg/lightning/common/retry.go +++ b/br/pkg/lightning/common/retry.go @@ -40,6 +40,8 @@ var retryableErrorMsgList = []string{ // this error happens on when distsql.Checksum calls TiKV // see https://github.com/pingcap/tidb/blob/2c3d4f1ae418881a95686e8b93d4237f2e76eec6/store/copr/coprocessor.go#L941 "coprocessor task terminated due to exceeding the deadline", + // fix https://github.com/pingcap/tidb/issues/51383 + "rate: wait", } func isRetryableFromErrorMessage(err error) bool { diff --git a/br/pkg/lightning/common/retry_test.go b/br/pkg/lightning/common/retry_test.go index 6e12c241381f7..5aa06237eeded 100644 --- 
a/br/pkg/lightning/common/retry_test.go
+++ b/br/pkg/lightning/common/retry_test.go
@@ -21,6 +21,7 @@ import (
 	"net"
 	"net/url"
 	"testing"
+	"time"
 
 	"github.com/go-sql-driver/mysql"
 	"github.com/pingcap/errors"
@@ -28,6 +29,7 @@ import (
 	drivererr "github.com/pingcap/tidb/store/driver/error"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/multierr"
+	"golang.org/x/time/rate"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
@@ -117,4 +119,13 @@ func TestIsRetryableError(t *testing.T) {
 	require.False(t, IsRetryableError(multierr.Combine(context.Canceled, &net.DNSError{IsTimeout: true})))
 
 	require.True(t, IsRetryableError(errors.New("other error: Coprocessor task terminated due to exceeding the deadline")))
+
+	// error from limiter
+	l := rate.NewLimiter(rate.Limit(1), 1)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	// context has 1 second timeout, can't wait for 10 seconds
+	err = l.WaitN(ctx, 10)
+	require.Error(t, err)
+	require.True(t, IsRetryableError(err))
 }
diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go
index db53bc9f86ad2..0ad133e16a741 100644
--- a/br/pkg/lightning/common/util.go
+++ b/br/pkg/lightning/common/util.go
@@ -436,13 +436,16 @@ type KvPair struct {
 	Key []byte
 	// Val is the value of the KV pair
 	Val []byte
-	// RowID is the row id of the KV pair.
+	// RowID identifies a KvPair when two KvPairs are equal in Key and Val. It has
+	// two sources:
+	//
+	// When the KvPair is generated from ADD INDEX, the RowID is the encoded handle.
+	//
+	// Otherwise, the RowID is derived from the row number in the source files and is
+	// encoded with `codec.EncodeComparableVarint`.
 	RowID []byte
 }
 
-// EncodeIntRowIDToBuf encodes an int64 row id to a buffer.
-var EncodeIntRowIDToBuf = codec.EncodeComparableVarint
-
 // EncodeIntRowID encodes an int64 row id.
 func EncodeIntRowID(rowID int64) []byte {
 	return codec.EncodeComparableVarint(nil, rowID)
diff --git a/br/pkg/lightning/importer/get_pre_info.go b/br/pkg/lightning/importer/get_pre_info.go
index 191fed628bdf6..c11041dfaa07e 100644
--- a/br/pkg/lightning/importer/get_pre_info.go
+++ b/br/pkg/lightning/importer/get_pre_info.go
@@ -91,6 +91,8 @@ type PreImportInfoGetter interface {
 
 // TargetInfoGetter defines the operations to get information from target.
 type TargetInfoGetter interface {
+	// FetchRemoteDBModels fetches the database structures from the remote target.
+	FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error)
 	// FetchRemoteTableModels fetches the table structures from the remote target.
 	FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error)
 	// CheckVersionRequirements performs the check whether the target satisfies the version requirements.
@@ -158,6 +160,11 @@ func NewTargetInfoGetterImpl(
 	}, nil
 }
 
+// FetchRemoteDBModels implements TargetInfoGetter.
+func (g *TargetInfoGetterImpl) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) {
+	return g.backend.FetchRemoteDBModels(ctx)
+}
+
 // FetchRemoteTableModels fetches the table structures from the remote target.
 // It implements the TargetInfoGetter interface.
func (g *TargetInfoGetterImpl) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) {
@@ -800,6 +807,12 @@ func (p *PreImportInfoGetterImpl) IsTableEmpty(ctx context.Context, schemaName s
 	return p.targetInfoGetter.IsTableEmpty(ctx, schemaName, tableName)
 }
 
+// FetchRemoteDBModels fetches the database structures from the remote target.
+// It implements the PreImportInfoGetter interface.
+func (p *PreImportInfoGetterImpl) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) {
+	return p.targetInfoGetter.FetchRemoteDBModels(ctx)
+}
+
 // FetchRemoteTableModels fetches the table structures from the remote target.
 // It implements the PreImportInfoGetter interface.
 func (p *PreImportInfoGetterImpl) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) {
diff --git a/br/pkg/lightning/importer/import.go b/br/pkg/lightning/importer/import.go
index b5cd1262fa98b..17cc060a1b32a 100644
--- a/br/pkg/lightning/importer/import.go
+++ b/br/pkg/lightning/importer/import.go
@@ -580,29 +580,47 @@ type restoreSchemaWorker struct {
 func (worker *restoreSchemaWorker) addJob(sqlStr string, job *schemaJob) error {
 	stmts, err := createIfNotExistsStmt(worker.parser, sqlStr, job.dbName, job.tblName)
 	if err != nil {
-		worker.logger.Warn("failed to rewrite statement, will use raw input instead",
-			zap.String("db", job.dbName),
-			zap.String("table", job.tblName),
-			zap.String("statement", sqlStr),
-			zap.Error(err))
-		job.stmts = []string{sqlStr}
-	} else {
-		job.stmts = stmts
+		return errors.Trace(err)
 	}
+	job.stmts = stmts
 	return worker.appendJob(job)
 }
 
 func (worker *restoreSchemaWorker) makeJobs(
 	dbMetas []*mydump.MDDatabaseMeta,
+	getDBs func(context.Context) ([]*model.DBInfo, error),
 	getTables func(context.Context, string) ([]*model.TableInfo, error),
 ) error {
 	defer func() {
 		close(worker.jobCh)
 		worker.quit()
 	}()
-	var err error
+
+	if len(dbMetas) == 0 {
+		return nil
+	}
+	// 1. restore databases, execute statements concurrently
+
+	dbs, err := getDBs(worker.ctx)
+	if err != nil {
+		worker.logger.Warn("get databases from downstream failed", zap.Error(err))
+	}
+	dbSet := make(set.StringSet, len(dbs))
+	for _, db := range dbs {
+		dbSet.Insert(db.Name.L)
+	}
+	for _, dbMeta := range dbMetas {
+		// if downstream already has this database, we can skip ddl job
+		if dbSet.Exist(strings.ToLower(dbMeta.Name)) {
+			worker.logger.Info(
+				"database already exists in downstream, skip processing the source file",
+				zap.String("db", dbMeta.Name),
+			)
+			continue
+		}
+
 		sql := dbMeta.GetSchema(worker.ctx, worker.store)
 		err = worker.addJob(sql, &schemaJob{
 			dbName: dbMeta.Name,
@@ -617,18 +635,28 @@ func (worker *restoreSchemaWorker) makeJobs(
 	if err != nil {
 		return err
 	}
+	// 2. restore tables, execute statements concurrently
+
 	for _, dbMeta := range dbMetas {
 		// we can ignore error here, and let check failed later if schema not match
-		tables, _ := getTables(worker.ctx, dbMeta.Name)
-		tableMap := make(map[string]struct{})
+		tables, err := getTables(worker.ctx, dbMeta.Name)
+		if err != nil {
+			worker.logger.Warn("get tables from downstream failed", zap.Error(err))
+		}
+		tableSet := make(set.StringSet, len(tables))
 		for _, t := range tables {
-			tableMap[t.Name.L] = struct{}{}
+			tableSet.Insert(t.Name.L)
 		}
 		for _, tblMeta := range dbMeta.Tables {
-			if _, ok := tableMap[strings.ToLower(tblMeta.Name)]; ok {
+			if tableSet.Exist(strings.ToLower(tblMeta.Name)) {
 				// we already has this table in TiDB.
 				// we should skip ddl job and let SchemaValid check.
+ worker.logger.Info( + "table already exists in downstream, skip processing the source file", + zap.String("db", dbMeta.Name), + zap.String("table", tblMeta.Name), + ) continue } else if tblMeta.SchemaFile.FileMeta.Path == "" { return common.ErrSchemaNotExists.GenWithStackByArgs(dbMeta.Name, tblMeta.Name) @@ -703,7 +731,6 @@ loop: var err error if session == nil { session, err = func() (*sql.Conn, error) { - // TODO: support lightning in SQL return worker.db.Conn(worker.ctx) }() if err != nil { @@ -826,7 +853,7 @@ func (rc *Controller) restoreSchema(ctx context.Context) error { for i := 0; i < concurrency; i++ { go worker.doJob() } - err := worker.makeJobs(rc.dbMetas, rc.preInfoGetter.FetchRemoteTableModels) + err := worker.makeJobs(rc.dbMetas, rc.preInfoGetter.FetchRemoteDBModels, rc.preInfoGetter.FetchRemoteTableModels) logTask.End(zap.ErrorLevel, err) if err != nil { return err diff --git a/br/pkg/lightning/importer/mock/mock.go b/br/pkg/lightning/importer/mock/mock.go index 6b0809729e1ef..d09fc83bb3925 100644 --- a/br/pkg/lightning/importer/mock/mock.go +++ b/br/pkg/lightning/importer/mock/mock.go @@ -212,6 +212,15 @@ func (t *TargetInfo) SetTableInfo(schemaName string, tableName string, tblInfo * t.dbTblInfoMap[schemaName][tableName] = tblInfo } +// FetchRemoteDBModels implements the TargetInfoGetter interface. +func (t *TargetInfo) FetchRemoteDBModels(_ context.Context) ([]*model.DBInfo, error) { + resultInfos := []*model.DBInfo{} + for dbName := range t.dbTblInfoMap { + resultInfos = append(resultInfos, &model.DBInfo{Name: model.NewCIStr(dbName)}) + } + return resultInfos, nil +} + // FetchRemoteTableModels fetches the table structures from the remote target. // It implements the TargetInfoGetter interface. func (t *TargetInfo) FetchRemoteTableModels(_ context.Context, schemaName string) ([]*model.TableInfo, error) { diff --git a/br/pkg/lightning/importer/restore_schema_test.go b/br/pkg/lightning/importer/restore_schema_test.go index b969e01bea358..e3c374295d3a4 100644 --- a/br/pkg/lightning/importer/restore_schema_test.go +++ b/br/pkg/lightning/importer/restore_schema_test.go @@ -136,6 +136,10 @@ func (s *restoreSchemaSuite) SetupTest() { s.controller, s.ctx = gomock.WithContext(context.Background(), s.T()) mockTargetInfoGetter := mock.NewMockTargetInfoGetter(s.controller) mockBackend := mock.NewMockBackend(s.controller) + mockTargetInfoGetter.EXPECT(). + FetchRemoteDBModels(gomock.Any()). + AnyTimes(). + Return([]*model.DBInfo{{Name: model.NewCIStr("fakedb")}}, nil) mockTargetInfoGetter.EXPECT(). FetchRemoteTableModels(gomock.Any(), gomock.Any()). AnyTimes(). 
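The makeJobs change above skips CREATE DATABASE jobs for schemas the downstream already holds, comparing names case-insensitively via set.StringSet. A minimal, self-contained sketch of that filtering step (plain strings stand in for model.DBInfo and MDDatabaseMeta; filterExistingDBs is a hypothetical helper, not the PR's API):

package main

import (
	"fmt"
	"strings"
)

// filterExistingDBs mirrors the new flow: databases already present
// downstream are skipped (compared case-insensitively) instead of
// re-issuing CREATE DATABASE IF NOT EXISTS for them.
func filterExistingDBs(sourceDBs, remoteDBs []string) []string {
	existing := make(map[string]struct{}, len(remoteDBs))
	for _, name := range remoteDBs {
		existing[strings.ToLower(name)] = struct{}{}
	}
	toCreate := make([]string, 0, len(sourceDBs))
	for _, name := range sourceDBs {
		if _, ok := existing[strings.ToLower(name)]; ok {
			continue // database already exists in downstream; skip the DDL job
		}
		toCreate = append(toCreate, name)
	}
	return toCreate
}

func main() {
	fmt.Println(filterExistingDBs([]string{"Sales", "logs"}, []string{"SALES"}))
	// prints: [logs]
}

Note that in the diff a failed FetchRemoteDBModels is only logged with a warning, so the worker degrades to issuing every CREATE DATABASE IF NOT EXISTS rather than aborting the import.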
diff --git a/br/pkg/lightning/lightning.go b/br/pkg/lightning/lightning.go index d9ec7536dd8c4..5e1ee36a0c7ba 100644 --- a/br/pkg/lightning/lightning.go +++ b/br/pkg/lightning/lightning.go @@ -753,8 +753,7 @@ func (l *Lightning) handlePostTask(w http.ResponseWriter, req *http.Request) { writeJSONError(w, http.StatusBadRequest, "cannot read request", err) return } - filteredData := utils.HideSensitive(string(data)) - log.L().Info("received task config", zap.String("content", filteredData)) + log.L().Info("received task config") cfg := config.NewConfig() if err = cfg.LoadFromGlobal(l.globalCfg); err != nil { diff --git a/br/pkg/lightning/mydump/loader.go b/br/pkg/lightning/mydump/loader.go index 309e55d6a86e6..95bb2f48902b3 100644 --- a/br/pkg/lightning/mydump/loader.go +++ b/br/pkg/lightning/mydump/loader.go @@ -785,7 +785,7 @@ func SampleFileCompressRatio(ctx context.Context, fileMeta SourceFileMeta, store // SampleParquetDataSize samples the data size of the parquet file. func SampleParquetDataSize(ctx context.Context, fileMeta SourceFileMeta, store storage.ExternalStorage) (int64, error) { totalRowCount, err := ReadParquetFileRowCountByFile(ctx, store, fileMeta) - if err != nil { + if totalRowCount == 0 || err != nil { return 0, err } diff --git a/br/pkg/lightning/mydump/loader_test.go b/br/pkg/lightning/mydump/loader_test.go index 69c3474d4cd1d..9320d4cfbe6de 100644 --- a/br/pkg/lightning/mydump/loader_test.go +++ b/br/pkg/lightning/mydump/loader_test.go @@ -1108,7 +1108,7 @@ func TestSampleFileCompressRatio(t *testing.T) { require.InDelta(t, ratio, 5000.0/float64(bf.Len()), 1e-5) } -func TestSampleParquetDataSize(t *testing.T) { +func testSampleParquetDataSize(t *testing.T, count int) { s := newTestMydumpLoaderSuite(t) store, err := storage.NewLocalStorage(s.sourceDir) require.NoError(t, err) @@ -1133,7 +1133,7 @@ func TestSampleParquetDataSize(t *testing.T) { t.Logf("seed: %d", seed) rand.Seed(seed) totalRowSize := 0 - for i := 0; i < 1000; i++ { + for i := 0; i < count; i++ { kl := rand.Intn(20) + 1 key := make([]byte, kl) kl, err = rand.Read(key) @@ -1166,3 +1166,8 @@ func TestSampleParquetDataSize(t *testing.T) { // expected error within 10%, so delta = totalRowSize / 10 require.InDelta(t, totalRowSize, size, float64(totalRowSize)/10) } + +func TestSampleParquetDataSize(t *testing.T) { + t.Run("count=1000", func(t *testing.T) { testSampleParquetDataSize(t, 1000) }) + t.Run("count=0", func(t *testing.T) { testSampleParquetDataSize(t, 0) }) +} diff --git a/br/pkg/mock/backend.go b/br/pkg/mock/backend.go index e32c73f3ceb7b..4e5c56c75d4d8 100644 --- a/br/pkg/mock/backend.go +++ b/br/pkg/mock/backend.go @@ -295,6 +295,21 @@ func (mr *MockTargetInfoGetterMockRecorder) CheckRequirements(arg0, arg1 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckRequirements", reflect.TypeOf((*MockTargetInfoGetter)(nil).CheckRequirements), arg0, arg1) } +// FetchRemoteDBModels mocks base method. +func (m *MockTargetInfoGetter) FetchRemoteDBModels(arg0 context.Context) ([]*model.DBInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchRemoteDBModels", arg0) + ret0, _ := ret[0].([]*model.DBInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchRemoteDBModels indicates an expected call of FetchRemoteDBModels. 
+func (mr *MockTargetInfoGetterMockRecorder) FetchRemoteDBModels(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchRemoteDBModels", reflect.TypeOf((*MockTargetInfoGetter)(nil).FetchRemoteDBModels), arg0) +} + // FetchRemoteTableModels mocks base method. func (m *MockTargetInfoGetter) FetchRemoteTableModels(arg0 context.Context, arg1 string) ([]*model.TableInfo, error) { m.ctrl.T.Helper() diff --git a/br/pkg/restore/db.go b/br/pkg/restore/db.go index 1f3f5d949e26e..3fe2f990800a9 100644 --- a/br/pkg/restore/db.go +++ b/br/pkg/restore/db.go @@ -254,7 +254,7 @@ func (db *DB) CreateTablePostRestore(ctx context.Context, table *metautil.Table, utils.EncloseName(table.DB.Name.O), utils.EncloseName(table.Info.Name.O), table.Info.AutoIncID) - } else if table.Info.PKIsHandle && table.Info.ContainsAutoRandomBits() { + } else if table.Info.ContainsAutoRandomBits() { restoreMetaSQL = fmt.Sprintf( "alter table %s.%s auto_random_base = %d", utils.EncloseName(table.DB.Name.O), diff --git a/br/pkg/utils/BUILD.bazel b/br/pkg/utils/BUILD.bazel index 08b697946e243..4c1e14d2d7c91 100644 --- a/br/pkg/utils/BUILD.bazel +++ b/br/pkg/utils/BUILD.bazel @@ -20,7 +20,6 @@ go_library( "retry.go", "safe_point.go", "schema.go", - "sensitive.go", "store_manager.go", "suspend_importing.go", "worker.go", @@ -89,12 +88,11 @@ go_test( "retry_test.go", "safe_point_test.go", "schema_test.go", - "sensitive_test.go", "suspend_importing_test.go", ], embed = [":utils"], flaky = True, - shard_count = 38, + shard_count = 37, deps = [ "//br/pkg/errors", "//br/pkg/metautil", diff --git a/br/pkg/utils/sensitive.go b/br/pkg/utils/sensitive.go deleted file mode 100644 index fcc31ee30b78d..0000000000000 --- a/br/pkg/utils/sensitive.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. - -package utils - -import ( - "regexp" -) - -var ( - passwordPatterns = `(password[\s]*=[\s]*(\\")?)(.*?)((\\")?\\n)` - - passwordRegexp *regexp.Regexp -) - -func init() { - passwordRegexp = regexp.MustCompile(passwordPatterns) -} - -// HideSensitive replace password with ******. -func HideSensitive(input string) string { - output := passwordRegexp.ReplaceAllString(input, "$1******$4") - return output -} diff --git a/br/pkg/utils/sensitive_test.go b/br/pkg/utils/sensitive_test.go deleted file mode 100644 index a14ce0619eb85..0000000000000 --- a/br/pkg/utils/sensitive_test.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
- -package utils - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestHideSensitive(t *testing.T) { - strs := []struct { - old string - new string - }{ - { - `host = "127.0.0.1"\n user = "root"\n password = "/Q7B9DizNLLTTfiZHv9WoEAKamfpIUs="\n port = 3306\n`, - `host = "127.0.0.1"\n user = "root"\n password = ******\n port = 3306\n`, - }, - { - `host = "127.0.0.1"\n user = "root"\n password = ""\n port = 3306\n`, - `host = "127.0.0.1"\n user = "root"\n password = ******\n port = 3306\n`, - }, - { - `host = "127.0.0.1"\n user = "root"\n password= "/Q7B9DizNLLTTfiZHv9WoEAKamfpIUs="\n port = 3306\n`, - `host = "127.0.0.1"\n user = "root"\n password= ******\n port = 3306\n`, - }, - { - `host = "127.0.0.1"\n user = "root"\n password =""\n port = 3306\n`, - `host = "127.0.0.1"\n user = "root"\n password =******\n port = 3306\n`, - }, - { - `host = "127.0.0.1"\n user = "root"\n password=""\n port = 3306\n`, - `host = "127.0.0.1"\n user = "root"\n password=******\n port = 3306\n`, - }, - } - for i, str := range strs { - t.Logf("case #%d\n", i) - require.Equal(t, str.new, HideSensitive(str.old)) - } -} diff --git a/br/tests/br_autorandom/run.sh b/br/tests/br_autorandom/run.sh new file mode 100644 index 0000000000000..d84403173db7a --- /dev/null +++ b/br/tests/br_autorandom/run.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# +# Copyright 2024 PingCAP, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eu +. run_services +CUR=$(cd `dirname $0`; pwd) + +# const value +PREFIX="autorandom" # NOTICE: don't start with 'br' because `restart services` would remove file/directory br*. 
+res_file="$TEST_DIR/sql_res.$TEST_NAME.txt"
+
+# start a new cluster
+echo "restart services"
+restart_services
+
+# prepare the data
+echo "prepare the data"
+run_sql "CREATE TABLE test.common (a BIGINT UNSIGNED AUTO_RANDOM(1), b VARCHAR(255), uid INT, c VARCHAR(255) DEFAULT 'c', PRIMARY KEY (a, b), UNIQUE INDEX (uid));"
+run_sql "INSERT INTO test.common (b, uid, c) values ('a', 1, 'a');"
+run_sql "INSERT INTO test.common (b, uid, c) values ('a', 2, 'a');"
+run_sql "INSERT INTO test.common (b, uid, c) values ('a', 3, 'a');"
+run_sql "INSERT INTO test.common (b, uid, c) values ('a', 4, 'a');"
+run_sql "INSERT INTO test.common (b, uid, c) values ('a', 5, 'a');"
+run_sql "INSERT INTO test.common (b, uid, c) values ('a', 6, 'a');"
+run_sql "INSERT INTO test.common (b, uid, c) values ('a', 7, 'a');"
+run_sql "INSERT INTO test.common (b, uid, c) values ('a', 8, 'a');"
+run_sql "INSERT INTO test.common (b, uid, c) values ('a', 9, 'a');"
+run_sql "INSERT INTO test.common (b, uid, c) values ('a', 10, 'a');"
+
+run_sql "CREATE TABLE test.pk (a BIGINT UNSIGNED AUTO_RANDOM(1), uid INT, c VARCHAR(255) DEFAULT 'c', PRIMARY KEY (a), UNIQUE INDEX (uid));"
+run_sql "INSERT INTO test.pk (uid, c) values (1, 'a');"
+run_sql "INSERT INTO test.pk (uid, c) values (2, 'a');"
+run_sql "INSERT INTO test.pk (uid, c) values (3, 'a');"
+run_sql "INSERT INTO test.pk (uid, c) values (4, 'a');"
+run_sql "INSERT INTO test.pk (uid, c) values (5, 'a');"
+run_sql "INSERT INTO test.pk (uid, c) values (6, 'a');"
+run_sql "INSERT INTO test.pk (uid, c) values (7, 'a');"
+run_sql "INSERT INTO test.pk (uid, c) values (8, 'a');"
+run_sql "INSERT INTO test.pk (uid, c) values (9, 'a');"
+run_sql "INSERT INTO test.pk (uid, c) values (10, 'a');"
+
+# backup & restore
+run_br --pd $PD_ADDR backup full -s "local://$TEST_DIR/$PREFIX/full"
+echo "restart services"
+restart_services
+run_br --pd $PD_ADDR restore full -s "local://$TEST_DIR/$PREFIX/full"
+
+# new workload
+for i in `seq 1 9`; do
+    run_sql "INSERT INTO test.common (b, uid) values ('a', 10) on duplicate key update c = 'b';"
+    run_sql "INSERT INTO test.pk (uid) values (10) on duplicate key update c = 'b';"
+done
+
+# check consistency
+run_sql "SELECT COUNT(*) AS RESCNT FROM test.common WHERE uid < 10 AND c = 'b';"
+check_contains "RESCNT: 0"
+run_sql "SELECT COUNT(*) AS RESCNT FROM test.pk WHERE uid < 10 AND c = 'b';"
+check_contains "RESCNT: 0"
diff --git a/br/tests/br_foreign_key/run.sh b/br/tests/br_foreign_key/run.sh
index eafd38a180d40..db6e1a4a1672c 100644
--- a/br/tests/br_foreign_key/run.sh
+++ b/br/tests/br_foreign_key/run.sh
@@ -17,37 +17,40 @@ set -eu
 DB="$TEST_NAME"
 
-run_sql "set @@global.foreign_key_checks=1;"
-run_sql "set @@foreign_key_checks=1;"
-run_sql "create schema $DB;"
-run_sql "create table $DB.t1 (id int key);"
-run_sql "create table $DB.t2 (id int key, a int, b int, foreign key fk_1 (a) references t1(id) ON UPDATE SET NULL ON DELETE SET NULL, foreign key fk_2 (b) references t1(id) ON DELETE CASCADE ON UPDATE CASCADE);"
-run_sql "insert into $DB.t1 values (1), (2), (3);"
-run_sql "insert into $DB.t2 values (1, 1, 1), (2, 2, 2), (3, 3, 3);"
-run_sql "update $DB.t1 set id=id+10 where id in (1, 3);"
-run_sql "delete from $DB.t1 where id = 2;"
-
-echo "backup start..."
-run_br backup db --db "$DB" -s "local://$TEST_DIR/$DB" --pd $PD_ADDR
-
-run_sql "drop schema $DB;"
-
-echo "restore start..."
-run_br restore db --db $DB -s "local://$TEST_DIR/$DB" --pd $PD_ADDR - -set -x - -run_sql "select count(*) from $DB.t1;" -check_contains 'count(*): 2' - -run_sql "select count(*) from $DB.t2;" -check_contains 'count(*): 2' - -run_sql "select id, a, b from $DB.t2;" -check_contains 'id: 1' -check_contains 'id: 3' -check_contains 'a: NULL' -check_contains 'b: 11' -check_contains 'b: 13' - -run_sql "drop schema $DB" +for DDL_BATCH_SIZE in 1 2; +do + run_sql "set @@global.foreign_key_checks=1;" + run_sql "set @@foreign_key_checks=1;" + run_sql "create schema $DB;" + run_sql "create table $DB.t1 (id int key);" + run_sql "create table $DB.t2 (id int key, a int, b int, foreign key fk_1 (a) references t1(id) ON UPDATE SET NULL ON DELETE SET NULL, foreign key fk_2 (b) references t1(id) ON DELETE CASCADE ON UPDATE CASCADE);" + run_sql "insert into $DB.t1 values (1), (2), (3);" + run_sql "insert into $DB.t2 values (1, 1, 1), (2, 2, 2), (3, 3, 3);" + run_sql "update $DB.t1 set id=id+10 where id in (1, 3);" + run_sql "delete from $DB.t1 where id = 2;" + + echo "backup start..." + run_br backup db --db "$DB" -s "local://$TEST_DIR/$DB-$DDL_BATCH_SIZE" --pd $PD_ADDR + + run_sql "drop schema $DB;" + + echo "restore start..." + run_br restore db --db $DB -s "local://$TEST_DIR/$DB-$DDL_BATCH_SIZE" --pd $PD_ADDR --ddl-batch-size=$DDL_BATCH_SIZE + + set -x + + run_sql "select count(*) from $DB.t1;" + check_contains 'count(*): 2' + + run_sql "select count(*) from $DB.t2;" + check_contains 'count(*): 2' + + run_sql "select id, a, b from $DB.t2;" + check_contains 'id: 1' + check_contains 'id: 3' + check_contains 'a: NULL' + check_contains 'b: 11' + check_contains 'b: 13' + + run_sql "drop schema $DB" +done diff --git a/br/tests/lightning_character_sets/run.sh b/br/tests/lightning_character_sets/run.sh index d1a7ea5728d16..4c09185853f95 100755 --- a/br/tests/lightning_character_sets/run.sh +++ b/br/tests/lightning_character_sets/run.sh @@ -78,6 +78,8 @@ check_contains 's: 5291' # test about unsupported charset in UTF-8 encoding dump files # test local backend run_lightning --config "tests/$TEST_NAME/greek.toml" -d "tests/$TEST_NAME/greek" 2>&1 | grep -q "Unknown character set: 'greek'" +# check TiDB does not receive the DDL +check_not_contains "greek" $TEST_DIR/tidb.log run_sql 'DROP DATABASE IF EXISTS charsets;' run_sql 'CREATE DATABASE charsets;' run_sql 'CREATE TABLE charsets.greek (c VARCHAR(20) PRIMARY KEY);' diff --git a/br/tests/lightning_checkpoint/run.sh b/br/tests/lightning_checkpoint/run.sh index 4e3d57ec6f158..7eabbad75046a 100755 --- a/br/tests/lightning_checkpoint/run.sh +++ b/br/tests/lightning_checkpoint/run.sh @@ -89,7 +89,7 @@ echo "******** Verify checkpoint no-op ********" run_lightning -d "$DBPATH" --enable-checkpoint=1 run_sql "$PARTIAL_IMPORT_QUERY" check_contains "s: $(( (1000 * $CHUNK_COUNT + 1001) * $CHUNK_COUNT * $TABLE_COUNT ))" -run_sql 'SELECT count(*) FROM `tidb_lightning_checkpoint_test_cppk`.table_v8 WHERE status >= 200' +run_sql 'SELECT count(*) FROM `tidb_lightning_checkpoint_test_cppk`.table_v9 WHERE status >= 200' check_contains "count(*): $TABLE_COUNT" # Start importing the tables. 
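The table_v8 to table_v9 renames in these test scripts track the checkpoint schema bump in checkpoints.go, where the table_info column widened from text to longtext (the comment there asks for a version bump on any incompatible change). A rough, self-contained illustration of why TEXT can be too small, using a hypothetical struct rather than TiDB's model.TableInfo: the JSON-serialized info of a table with a few thousand columns overflows TEXT's 65,535-byte cap, while LONGTEXT holds up to 4 GiB.

package main

import (
	"encoding/json"
	"fmt"
)

// columnInfo is a stand-in for a real column definition; only the
// serialized size matters here, not the exact schema.
type columnInfo struct {
	Name string `json:"name"`
	Type string `json:"type"`
}

func main() {
	cols := make([]columnInfo, 2000)
	for i := range cols {
		cols[i] = columnInfo{Name: fmt.Sprintf("col_%04d", i), Type: "varchar(255)"}
	}
	data, err := json.Marshal(map[string][]columnInfo{"columns": cols})
	if err != nil {
		panic(err)
	}
	// With 2000 columns this prints roughly 80,000 bytes, past the 65,535-byte TEXT cap.
	fmt.Printf("serialized table info: %d bytes (TEXT caps at 65535)\n", len(data))
}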
@@ -113,5 +113,5 @@ echo "******** Verify checkpoint no-op ********" run_lightning -d "$DBPATH" --enable-checkpoint=1 run_sql "$PARTIAL_IMPORT_QUERY" check_contains "s: $(( (1000 * $CHUNK_COUNT + 1001) * $CHUNK_COUNT * $TABLE_COUNT ))" -run_sql 'SELECT count(*) FROM `tidb_lightning_checkpoint_test_cppk`.table_v8 WHERE status >= 200' +run_sql 'SELECT count(*) FROM `tidb_lightning_checkpoint_test_cppk`.table_v9 WHERE status >= 200' check_contains "count(*): $TABLE_COUNT" diff --git a/br/tests/lightning_checkpoint_chunks/run.sh b/br/tests/lightning_checkpoint_chunks/run.sh index cf596ed574354..382fc2f3f2005 100755 --- a/br/tests/lightning_checkpoint_chunks/run.sh +++ b/br/tests/lightning_checkpoint_chunks/run.sh @@ -33,7 +33,7 @@ verify_checkpoint_noop() { run_sql 'SELECT count(i), sum(i) FROM cpch_tsr.tbl;' check_contains "count(i): $(($ROW_COUNT*$CHUNK_COUNT))" check_contains "sum(i): $(( $ROW_COUNT*$CHUNK_COUNT*(($CHUNK_COUNT+2)*$ROW_COUNT + 1)/2 ))" - run_sql 'SELECT count(*) FROM `tidb_lightning_checkpoint_test_cpch.1234567890.bak`.table_v8 WHERE status >= 200' + run_sql 'SELECT count(*) FROM `tidb_lightning_checkpoint_test_cpch.1234567890.bak`.table_v9 WHERE status >= 200' check_contains "count(*): 1" } diff --git a/br/tests/run_group.sh b/br/tests/run_group.sh index e454fec47cfad..8ba4e76e689e8 100755 --- a/br/tests/run_group.sh +++ b/br/tests/run_group.sh @@ -29,7 +29,7 @@ groups=( ["G05"]='br_range br_rawkv br_replica_read br_restore_TDE_enable br_restore_log_task_enable br_s3 br_shuffle_leader br_shuffle_region br_single_table' ["G06"]='br_skip_checksum br_small_batch_size br_split_region_fail br_systables br_table_filter br_txn' ["G07"]='br_clustered_index br_crypter br_table_partition br_tidb_placement_policy br_tiflash br_tikv_outage' - ["G08"]='br_tikv_outage2 br_ttl br_views_and_sequences br_z_gc_safepoint lightning_add_index lightning_alter_random lightning_auto_columns' + ["G08"]='br_tikv_outage2 br_ttl br_views_and_sequences br_z_gc_safepoint br_autorandom lightning_add_index lightning_alter_random lightning_auto_columns' ["G09"]='lightning_auto_random_default lightning_bom_file lightning_character_sets lightning_check_partial_imported lightning_checkpoint lightning_checkpoint_chunks lightning_checkpoint_columns lightning_checkpoint_dirty_tableid' ["G10"]='lightning_checkpoint_engines lightning_checkpoint_engines_order lightning_checkpoint_error_destroy lightning_checkpoint_parquet lightning_checkpoint_timestamp lightning_checksum_mismatch lightning_cmdline_override lightning_column_permutation lightning_common_handle' ["G11"]='lightning_compress lightning_concurrent-restore lightning_config_max_error lightning_config_skip_csv_header lightning_csv lightning_default-columns lightning_disable_scheduler_by_key_range lightning_disk_quota lightning_distributed_import' diff --git a/build/nogo_config.json b/build/nogo_config.json index a63305e1cc197..985cc01b2b695 100644 --- a/build/nogo_config.json +++ b/build/nogo_config.json @@ -144,6 +144,7 @@ }, "fieldalignment": { "exclude_files": { + "pkg/statistics/table.go": "disable this limitation that prevents us from splitting struct fields for clarity", "external/": "no need to vet third party code", ".*_generated\\.go$": "ignore generated code", ".*_/testmain\\.go$": "ignore code", diff --git a/config/BUILD.bazel b/config/BUILD.bazel index db2e165331bca..299e001e2cf90 100644 --- a/config/BUILD.bazel +++ b/config/BUILD.bazel @@ -37,7 +37,7 @@ go_test( data = glob(["**"]), embed = [":config"], flaky = True, - shard_count = 23, + 
shard_count = 24, deps = [ "//testkit/testsetup", "//util/logutil", diff --git a/config/config.go b/config/config.go index 63c181d51675e..508dbe27fca60 100644 --- a/config/config.go +++ b/config/config.go @@ -1237,13 +1237,16 @@ func (c *Config) RemovedVariableCheck(confFile string) error { // Load loads config options from a toml file. func (c *Config) Load(confFile string) error { metaData, err := toml.DecodeFile(confFile, c) + if err != nil { + return err + } if c.TokenLimit == 0 { c.TokenLimit = 1000 } // If any items in confFile file are not mapped into the Config struct, issue // an error and stop the server from starting. undecoded := metaData.Undecoded() - if len(undecoded) > 0 && err == nil { + if len(undecoded) > 0 { var undecodedItems []string for _, item := range undecoded { undecodedItems = append(undecodedItems, item.String()) diff --git a/config/config_test.go b/config/config_test.go index 5523165b935d9..42307f3e37280 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1329,3 +1329,25 @@ func TestAutoScalerConfig(t *testing.T) { conf.UseAutoScaler = false }) } + +func TestInvalidConfigWithDeprecatedConfig(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.toml") + + f, err := os.Create(configFile) + require.NoError(t, err) + + _, err = f.WriteString(` +[log] +slow-threshold = 1000 +[performance] +enforce-mpp = 1 + `) + require.NoError(t, err) + require.NoError(t, f.Sync()) + + var conf Config + err = conf.Load(configFile) + require.Error(t, err) + require.Equal(t, err.Error(), "toml: line 5 (last key \"performance.enforce-mpp\"): incompatible types: TOML value has type int64; destination has type boolean") +} diff --git a/ddl/column.go b/ddl/column.go index 572474294736a..9d95dbe64c3bf 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -2041,3 +2041,12 @@ func getChangingColumnOriginName(changingColumn *model.ColumnInfo) string { } return columnName[:pos] } + +func getExpressionIndexOriginName(expressionIdx *model.ColumnInfo) string { + columnName := strings.TrimPrefix(expressionIdx.Name.O, expressionIndexPrefix+"_") + var pos int + if pos = strings.LastIndex(columnName, "_"); pos == -1 { + return columnName + } + return columnName[:pos] +} diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go index 5277bb1d86929..9dca70c504340 100644 --- a/ddl/db_integration_test.go +++ b/ddl/db_integration_test.go @@ -2135,6 +2135,13 @@ func TestDefaultColumnWithUUID(t *testing.T) { " `c` int(10) DEFAULT NULL,\n" + " `c1` varchar(256) DEFAULT uuid()\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + // test modify column + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (c int(10), c1 varchar(256) default rand());") + tk.MustExec("alter table t alter column c1 set default 'xx';") + tk.MustExec("insert into t values (1, default);") + tk.MustQuery("select c1 from t;").Check(testkit.Rows("xx")) } func TestChangingDBCharset(t *testing.T) { diff --git a/ddl/db_test.go b/ddl/db_test.go index aefcad56e751a..2ec42634d50ed 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/ddl/testutil" ddlutil "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/ddl/util/callback" "github.com/pingcap/tidb/domain" @@ -52,6 +53,7 @@ import ( "github.com/pingcap/tidb/util/dbterror" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/sqlexec" + 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" @@ -1621,3 +1623,57 @@ func TestMDLTruncateTable(t *testing.T) { require.True(t, timetk2.After(timeMain)) require.True(t, timetk3.After(timeMain)) } + +func TestInsertIgnore(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(a smallint(6) DEFAULT '-13202', b varchar(221) NOT NULL DEFAULT 'duplicatevalue', " + + "c tinyint(1) NOT NULL DEFAULT '0', PRIMARY KEY (c, b));") + + tk1 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + + d := dom.DDL() + originalCallback := d.GetHook() + defer d.SetHook(originalCallback) + callback := &callback.TestDDLCallback{} + + onJobUpdatedExportedFunc := func(job *model.Job) { + switch job.SchemaState { + case model.StateDeleteOnly: + _, err := tk1.Exec("INSERT INTO t VALUES (-18585,'aaa',1), (-18585,'0',1), (-18585,'1',1), (-18585,'duplicatevalue',1);") + assert.NoError(t, err) + case model.StateWriteReorganization: + idx := testutil.FindIdxInfo(dom, "test", "t", "idx") + if idx.BackfillState == model.BackfillStateReadyToMerge { + _, err := tk1.Exec("insert ignore into `t` values ( 234,'duplicatevalue',-2028 );") + assert.NoError(t, err) + return + } + } + } + callback.OnJobUpdatedExported.Store(&onJobUpdatedExportedFunc) + d.SetHook(callback) + + tk.MustExec("alter table t add unique index idx(b);") + tk.MustExec("admin check table t;") +} + +func TestDDLJobErrEntrySizeTooLarge(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec("create table t (a int);") + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/ddl/mockErrEntrySizeTooLarge", `1*return(true)`)) + t.Cleanup(func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/ddl/mockErrEntrySizeTooLarge")) + }) + + tk.MustGetErrCode("rename table t to t1;", errno.ErrEntryTooLarge) + tk.MustExec("create table t1 (a int);") + tk.MustExec("alter table t add column b int;") // Should not block. +} diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index a19d1f178749d..0b32474506c7c 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -4546,7 +4546,7 @@ func checkExchangePartition(pt *model.TableInfo, nt *model.TableInfo) error { return errors.Trace(dbterror.ErrPartitionExchangePartTable.GenWithStackByArgs(nt.Name)) } - if nt.ForeignKeys != nil { + if len(nt.ForeignKeys) > 0 { return errors.Trace(dbterror.ErrPartitionExchangeForeignKey.GenWithStackByArgs(nt.Name)) } @@ -5561,8 +5561,8 @@ func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Alt // Clean the NoDefaultValueFlag value. col.DelFlag(mysql.NoDefaultValueFlag) + col.DefaultIsExpr = false if len(specNewColumn.Options) == 0 { - col.DefaultIsExpr = false err = col.SetDefaultValue(nil) if err != nil { return errors.Trace(err) diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go index dc2d280f0fb8d..926c2f3ef255b 100644 --- a/ddl/ddl_worker.go +++ b/ddl/ddl_worker.go @@ -423,7 +423,6 @@ func (w *worker) handleUpdateJobError(t *meta.Meta, job *model.Job, err error) e } // Reduce this txn entry size. 
job.BinlogInfo.Clean() - job.InvolvingSchemaInfo = nil job.Error = toTError(err) job.ErrorCount++ job.SchemaState = model.StateNone diff --git a/ddl/index.go b/ddl/index.go index fc51e9f774a01..fbcdefc7d5a1f 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -392,6 +392,8 @@ func onRenameIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) } renameIndexes(tblInfo, from, to) + renameHiddenColumns(tblInfo, from, to) + if ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true); err != nil { job.State = model.JobStateCancelled return ver, errors.Trace(err) @@ -725,6 +727,10 @@ func pickBackfillType(ctx context.Context, job *model.Job, unique bool, d *ddlCt job.ReorgMeta.ReorgTp = model.ReorgTypeTxn return model.ReorgTypeTxn, nil } + if hasSysDB(job) { + job.ReorgMeta.ReorgTp = model.ReorgTypeTxn + return model.ReorgTypeTxn, nil + } if ingest.LitInitialized { available, err := ingest.LitBackCtxMgr.CheckAvailable() if err != nil { @@ -2284,3 +2290,12 @@ func renameIndexes(tblInfo *model.TableInfo, from, to model.CIStr) { } } } + +func renameHiddenColumns(tblInfo *model.TableInfo, from, to model.CIStr) { + for _, col := range tblInfo.Columns { + if col.Hidden && getExpressionIndexOriginName(col) == from.O { + col.Name.L = strings.Replace(col.Name.L, from.L, to.L, 1) + col.Name.O = strings.Replace(col.Name.O, from.O, to.O, 1) + } + } +} diff --git a/ddl/integration_test.go b/ddl/integration_test.go index 46f26fa5e8237..4a00bab84a16b 100644 --- a/ddl/integration_test.go +++ b/ddl/integration_test.go @@ -142,3 +142,18 @@ func TestDDLOnCachedTable(t *testing.T) { tk.MustExec("alter table t nocache;") tk.MustExec("drop table if exists t;") } + +func TestExchangePartitionAfterDropForeignKey(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test;") + + tk.MustExec("create table parent (id int unique);") + tk.MustExec("create table child (id int, parent_id int, foreign key (parent_id) references parent(id));") + tk.MustExec("create table child_with_partition(id int, parent_id int) partition by range(id) (partition p1 values less than (100));") + tk.MustGetErrMsg("alter table child_with_partition exchange partition p1 with table child;", "[ddl:1740]Table to exchange with partition has foreign key references: 'child'") + tk.MustExec("alter table child drop foreign key fk_1;") + tk.MustExec("alter table child drop key fk_1;") + tk.MustExec("alter table child_with_partition exchange partition p1 with table child;") +} diff --git a/ddl/metadatalocktest/BUILD.bazel b/ddl/metadatalocktest/BUILD.bazel index d458d7d592368..c1287f1eb14db 100644 --- a/ddl/metadatalocktest/BUILD.bazel +++ b/ddl/metadatalocktest/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "mdl_test.go", ], flaky = True, - shard_count = 32, + shard_count = 34, deps = [ "//config", "//ddl", diff --git a/ddl/metadatalocktest/mdl_test.go b/ddl/metadatalocktest/mdl_test.go index be97f9e1f3f50..67c084df68036 100644 --- a/ddl/metadatalocktest/mdl_test.go +++ b/ddl/metadatalocktest/mdl_test.go @@ -432,6 +432,90 @@ func TestMDLAutoCommitReadOnly(t *testing.T) { require.Greater(t, ts1, ts2) } +func TestMDLAnalyze(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + sv := server.CreateMockServer(t, store) + + sv.SetDomain(dom) + dom.InfoSyncer().SetSessionManager(sv) + defer sv.Close() + + conn1 := server.CreateMockConn(t, sv) + tk := testkit.NewTestKitWithSession(t, store, conn1.Context().Session) + conn2 := server.CreateMockConn(t, sv) + tkDDL := 
testkit.NewTestKitWithSession(t, store, conn2.Context().Session)
+	tk.MustExec("use test")
+	tk.MustExec("set global tidb_enable_metadata_lock=1")
+	tk.MustExec("create table t(a int);")
+	tk.MustExec("insert into t values(1);")
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+	var ts2 time.Time
+	var ts1 time.Time
+
+	go func() {
+		tk.MustExec("begin")
+		tk.MustExec("analyze table t;")
+		tk.MustQuery("select sleep(2);")
+		tk.MustExec("commit")
+		ts1 = time.Now()
+		wg.Done()
+	}()
+
+	go func() {
+		tkDDL.MustExec("alter table test.t add column b int;")
+		ts2 = time.Now()
+		wg.Done()
+	}()
+
+	wg.Wait()
+	require.Greater(t, ts1, ts2)
+}
+
+func TestMDLAnalyzePartition(t *testing.T) {
+	store, dom := testkit.CreateMockStoreAndDomain(t)
+	sv := server.CreateMockServer(t, store)
+
+	sv.SetDomain(dom)
+	dom.InfoSyncer().SetSessionManager(sv)
+	defer sv.Close()
+
+	conn1 := server.CreateMockConn(t, sv)
+	tk := testkit.NewTestKitWithSession(t, store, conn1.Context().Session)
+	conn2 := server.CreateMockConn(t, sv)
+	tkDDL := testkit.NewTestKitWithSession(t, store, conn2.Context().Session)
+	tk.MustExec("use test")
+	tk.MustExec("set @@tidb_partition_prune_mode='dynamic'")
+	tk.MustExec("set global tidb_enable_metadata_lock=1")
+	tk.MustExec("create table t(a int) partition by range(a) ( PARTITION p0 VALUES LESS THAN (0), PARTITION p1 VALUES LESS THAN (100), PARTITION p2 VALUES LESS THAN MAXVALUE );")
+	tk.MustExec("insert into t values(1), (2), (3), (4);")
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+	var ts2 time.Time
+	var ts1 time.Time
+
+	go func() {
+		tk.MustExec("begin")
+		tk.MustExec("analyze table t;")
+		tk.MustExec("analyze table t partition p1;")
+		tk.MustQuery("select sleep(2);")
+		tk.MustExec("commit")
+		ts1 = time.Now()
+		wg.Done()
+	}()
+
+	go func() {
+		tkDDL.MustExec("alter table test.t drop partition p2;")
+		ts2 = time.Now()
+		wg.Done()
+	}()
+
+	wg.Wait()
+	require.Greater(t, ts1, ts2)
+}
+
 func TestMDLAutoCommitNonReadOnly(t *testing.T) {
 	store, dom := testkit.CreateMockStoreAndDomain(t)
 	sv := server.CreateMockServer(t, store)
diff --git a/ddl/table.go b/ddl/table.go
index bbc2599ceae81..191d102f60dd8 100644
--- a/ddl/table.go
+++ b/ddl/table.go
@@ -177,8 +177,11 @@ func onCreateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error)
 
 func createTableWithForeignKeys(d *ddlCtx, t *meta.Meta, job *model.Job, tbInfo *model.TableInfo, fkCheck bool) (ver int64, err error) {
 	switch tbInfo.State {
-	case model.StateNone:
-		// create table in non-public state
+	case model.StateNone, model.StatePublic:
+		// Create a table in non-public or public state. The function `createTable` always
+		// resets `tbInfo.State` to `model.StateNone`, so it is fine to call `createTable`
+		// with a table in public state.
+		// When `br` restores a table, the state of `tbInfo` is public.
 		tbInfo, err = createTable(d, t, job, fkCheck)
 		if err != nil {
 			return ver, errors.Trace(err)
diff --git a/executor/compact_table.go b/executor/compact_table.go
index c5960619c54b0..8d2cc3386b1bd 100644
--- a/executor/compact_table.go
+++ b/executor/compact_table.go
@@ -15,7 +15,6 @@ package executor
 
 import (
-	"bytes"
 	"context"
 	"encoding/hex"
 	"time"
@@ -310,7 +309,7 @@ func (task *storeCompactTask) compactOnePhysicalTable(physicalTableID int64) (bo
 
 		// Let's send more compact requests, as there are remaining data to compact.
 		lastEndKey := resp.GetCompactedEndKey()
-		if len(lastEndKey) == 0 || bytes.Compare(lastEndKey, startKey) <= 0 {
+		if len(lastEndKey) == 0 {
 			// The TiFlash server returned an invalid compacted end key.
// This is unexpected... warn := errors.Errorf("compact on store %s failed: internal error (check logs for details)", task.targetStore.Address) diff --git a/executor/executor_failpoint_test.go b/executor/executor_failpoint_test.go index 70630ad184a4a..e7f845d6171a9 100644 --- a/executor/executor_failpoint_test.go +++ b/executor/executor_failpoint_test.go @@ -585,3 +585,26 @@ func TestGetMvccByEncodedKeyRegionError(t *testing.T) { require.Equal(t, 1, len(resp.Info.Writes)) require.Equal(t, commitTs, resp.Info.Writes[0].CommitTs) } + +func TestShuffleExit(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(i int, j int, k int);") + tk.MustExec("insert into t1 VALUES (1,1,1),(2,2,2),(3,3,3),(4,4,4);") + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/shuffleError", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/shuffleError")) + }() + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/shuffleExecFetchDataAndSplit", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/shuffleExecFetchDataAndSplit")) + }() + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/shuffleWorkerRun", "panic(\"ShufflePanic\")")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/shuffleWorkerRun")) + }() + err := tk.QueryToErr("SELECT SUM(i) OVER W FROM t1 WINDOW w AS (PARTITION BY j ORDER BY i) ORDER BY 1+SUM(i) OVER w;") + require.ErrorContains(t, err, "ShuffleExec.Next error") +} diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index f4bdfe2a7a460..29938e2462fab 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -1864,11 +1864,11 @@ func (e *memtableRetriever) setDataForTiDBHotRegions(ctx sessionctx.Context) err if !ok { return errors.New("Information about hot region can be gotten only when the storage is TiKV") } - allSchemas := ctx.GetInfoSchema().(infoschema.InfoSchema).AllSchemas() tikvHelper := &helper.Helper{ Store: tikvStore, RegionCache: tikvStore.GetRegionCache(), } + allSchemas := tikvHelper.FilterMemDBs(ctx.GetInfoSchema().(infoschema.InfoSchema).AllSchemas()) metrics, err := tikvHelper.ScrapeHotInfo(pdapi.HotRead, allSchemas) if err != nil { return err diff --git a/executor/insert_common.go b/executor/insert_common.go index 8baf10c09fb5d..cb5237785dbdb 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -1171,6 +1171,28 @@ func (e *InsertValues) collectRuntimeStatsEnabled() bool { return false } +func (e *InsertValues) handleDuplicateKey(ctx context.Context, txn kv.Transaction, uk *keyValueWithDupInfo, replace bool, r toBeCheckedRow) (bool, error) { + if !replace { + e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr) + if txnCtx := e.ctx.GetSessionVars().TxnCtx; txnCtx.IsPessimistic && e.ctx.GetSessionVars().LockUnchangedKeys { + txnCtx.AddUnchangedKeyForLock(uk.newKey) + } + return true, nil + } + _, handle, err := tables.FetchDuplicatedHandle(ctx, uk.newKey, true, txn, e.Table.Meta().ID, uk.commonHandle) + if err != nil { + return false, err + } + if handle == nil { + return false, nil + } + _, err = e.removeRow(ctx, txn, handle, r, true) + if err != nil { + return false, err + } + return false, nil +} + // batchCheckAndInsert checks rows with duplicate errors. 
// All duplicate rows will be ignored and appended as duplicate warnings. func (e *InsertValues) batchCheckAndInsert( @@ -1221,7 +1243,6 @@ func (e *InsertValues) batchCheckAndInsert( } // append warnings and get no duplicated error rows -CheckAndInsert: for i, r := range toBeCheckedRows { if r.ignored { continue @@ -1258,43 +1279,44 @@ CheckAndInsert: } } + rowInserted := false for _, uk := range r.uniqueKeys { _, err := txn.Get(ctx, uk.newKey) + if err != nil && !kv.IsErrNotFound(err) { + return err + } if err == nil { - if replace { - _, handle, err := tables.FetchDuplicatedHandle( - ctx, - uk.newKey, - true, - txn, - e.Table.Meta().ID, - uk.commonHandle, - ) - if err != nil { - return err - } - if handle == nil { - continue - } - _, err = e.removeRow(ctx, txn, handle, r, true) + rowInserted, err = e.handleDuplicateKey(ctx, txn, uk, replace, r) + if err != nil { + return err + } + if rowInserted { + break + } + continue + } + if tablecodec.IsTempIndexKey(uk.newKey) { + tablecodec.TempIndexKey2IndexKey(uk.newKey) + _, err = txn.Get(ctx, uk.newKey) + if err != nil && !kv.IsErrNotFound(err) { + return err + } + if err == nil { + rowInserted, err = e.handleDuplicateKey(ctx, txn, uk, replace, r) if err != nil { return err } - } else { - // If duplicate keys were found in BatchGet, mark row = nil. - e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr) - if txnCtx := e.ctx.GetSessionVars().TxnCtx; txnCtx.IsPessimistic && - e.ctx.GetSessionVars().LockUnchangedKeys { - // lock duplicated unique key on insert-ignore - txnCtx.AddUnchangedKeyForLock(uk.newKey) + if rowInserted { + break } - continue CheckAndInsert } - } else if !kv.IsErrNotFound(err) { - return err } } + if rowInserted { + continue + } + // If row was checked with no duplicate keys, // it should be added to values map for the further row check. // There may be duplicate keys inside the insert statement. 
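The restructured check above is easier to follow as a single decision per unique key: if the key (or its merged temp-index form) already exists, INSERT IGNORE records a warning and drops the row, while REPLACE removes the conflicting row and keeps going. Below is a minimal, self-contained Go sketch of that decision flow; the `lookup` and `removeConflicting` helpers are hypothetical stand-ins for the real `txn.Get` and `FetchDuplicatedHandle` + `removeRow` machinery, and the return convention is simplified compared to the executor's.

```go
package main

import "fmt"

// resolveDuplicate sketches the per-unique-key decision in batchCheckAndInsert:
// on a duplicate, INSERT IGNORE records a warning and skips the row, while
// REPLACE removes the conflicting row and lets the insert proceed.
func resolveDuplicate(key string, replace bool, lookup func(string) bool, removeConflicting func(string)) (skipRow bool) {
	if !lookup(key) {
		return false // no duplicate for this key; check the next one
	}
	if !replace {
		fmt.Printf("warning: duplicate entry for key %q\n", key)
		return true // INSERT IGNORE: drop the row
	}
	removeConflicting(key) // REPLACE: clear the conflict, keep the row
	return false
}

func main() {
	existing := map[string]bool{"uk1": true}
	lookup := func(k string) bool { return existing[k] }
	remove := func(k string) { delete(existing, k) }

	fmt.Println(resolveDuplicate("uk1", false, lookup, remove)) // true: row ignored with a warning
	fmt.Println(resolveDuplicate("uk1", true, lookup, remove))  // false: conflict removed, row kept
}
```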
diff --git a/executor/shuffle.go b/executor/shuffle.go index 7596d2cef1970..371e847cbbf9a 100644 --- a/executor/shuffle.go +++ b/executor/shuffle.go @@ -17,6 +17,7 @@ package executor import ( "context" "sync" + "time" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -113,7 +114,7 @@ func (e *ShuffleExec) Open(ctx context.Context) error { e.prepared = false e.finishCh = make(chan struct{}, 1) - e.outputCh = make(chan *shuffleOutput, e.concurrency) + e.outputCh = make(chan *shuffleOutput, e.concurrency+len(e.dataSources)) for _, w := range e.workers { w.finishCh = e.finishCh @@ -199,13 +200,13 @@ func (e *ShuffleExec) Close() error { } func (e *ShuffleExec) prepare4ParallelExec(ctx context.Context) { + waitGroup := &sync.WaitGroup{} + waitGroup.Add(len(e.workers) + len(e.dataSources)) // create a goroutine for each dataSource to fetch and split data for i := range e.dataSources { - go e.fetchDataAndSplit(ctx, i) + go e.fetchDataAndSplit(ctx, i, waitGroup) } - waitGroup := &sync.WaitGroup{} - waitGroup.Add(len(e.workers)) for _, w := range e.workers { go w.run(ctx, waitGroup) } @@ -256,7 +257,7 @@ func recoveryShuffleExec(output chan *shuffleOutput, r interface{}) { logutil.BgLogger().Error("shuffle panicked", zap.Error(err), zap.Stack("stack")) } -func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int) { +func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int, waitGroup *sync.WaitGroup) { var ( err error workerIndices []int @@ -271,8 +272,16 @@ func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int for _, w := range e.workers { close(w.receivers[dataSourceIndex].inputCh) } + waitGroup.Done() }() + failpoint.Inject("shuffleExecFetchDataAndSplit", func(val failpoint.Value) { + if val.(bool) { + time.Sleep(100 * time.Millisecond) + panic("shuffleExecFetchDataAndSplitPanic") + } + }) + for { err = Next(ctx, e.dataSources[dataSourceIndex], chk) if err != nil { @@ -386,6 +395,7 @@ func (e *shuffleWorker) run(ctx context.Context, waitGroup *sync.WaitGroup) { waitGroup.Done() }() + failpoint.Inject("shuffleWorkerRun", nil) for { select { case <-e.finishCh: diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index c9016f9634405..31637c6ea8e58 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -874,7 +874,7 @@ func (b *builtinCastStringAsJSONSig) evalJSON(row chunk.Row) (res types.BinaryJS typ := b.args[0].GetType() if types.IsBinaryStr(typ) { buf := []byte(val) - if typ.GetType() == mysql.TypeString { + if typ.GetType() == mysql.TypeString && typ.GetFlen() > 0 { // the tailing zero should also be in the opaque json buf = make([]byte, typ.GetFlen()) copy(buf, val) diff --git a/expression/builtin_cast_vec.go b/expression/builtin_cast_vec.go index 12adfe39e6b4f..c97baef320b8f 100644 --- a/expression/builtin_cast_vec.go +++ b/expression/builtin_cast_vec.go @@ -844,7 +844,7 @@ func (b *builtinCastStringAsJSONSig) vecEvalJSON(input *chunk.Chunk, result *chu val := buf.GetBytes(i) resultBuf := val - if typ.GetType() == mysql.TypeString { + if typ.GetType() == mysql.TypeString && typ.GetFlen() > 0 { // only for BINARY: the tailing zero should also be in the opaque json resultBuf = make([]byte, typ.GetFlen()) copy(resultBuf, val) diff --git a/expression/constant_fold.go b/expression/constant_fold.go index 0d73757b3161b..f51b86b8cd337 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -100,7 +100,15 @@ func ifNullFoldHandler(expr *ScalarFunction) 
(Expression, bool) {
 	// evaluated to constArg.Value after foldConstant(args[0]), it's not
 	// needed to be checked.
 	if constArg.Value.IsNull() {
-		return foldConstant(args[1])
+		foldedExpr, isConstant := foldConstant(args[1])
+
+		// See https://github.com/pingcap/tidb/issues/51765. If the first argument can
+		// be folded into NULL, the collation of IFNULL should be the same as the second
+		// argument.
+		expr.GetType().SetCharset(args[1].GetType().GetCharset())
+		expr.GetType().SetCollate(args[1].GetType().GetCollate())
+
+		return foldedExpr, isConstant
 	}
 	return constArg, isDeferred
 }
@@ -157,6 +165,10 @@ func foldConstant(expr Expression) (Expression, bool) {
 		if _, ok := unFoldableFunctions[x.FuncName.L]; ok {
 			return expr, false
 		}
+		if _, ok := x.Function.(*extensionFuncSig); ok {
+			// we should not fold extension functions, because they may have side effects.
+			return expr, false
+		}
 		if function := specialFoldHandler[x.FuncName.L]; function != nil && !MaybeOverOptimized4PlanCache(x.GetCtx(), []Expression{expr}) {
 			return function(x)
 		}
diff --git a/expression/extension.go b/expression/extension.go
index 9ab506213d5f0..e67b4b549347d 100644
--- a/expression/extension.go
+++ b/expression/extension.go
@@ -97,7 +97,7 @@ func newExtensionFuncClass(def *extension.FunctionDef) (*extensionFuncClass, err
 }
 
 func (c *extensionFuncClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
-	if err := c.checkPrivileges(ctx); err != nil {
+	if err := checkPrivileges(ctx, &c.funcDef); err != nil {
 		return nil, err
 	}
 
@@ -108,13 +108,18 @@ func (c *extensionFuncClass) getFunction(ctx sessionctx.Context, args []Expressi
 	if err != nil {
 		return nil, err
 	}
+
+	// Although `getFunction` currently does not depend on anything that would make caching unsafe,
+	// we still skip the plan cache for extension functions because there is no strong requirement
+	// to support it, and skipping it keeps the behavior simple.
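+	// NewNoStackError is used below because skipping is an expected, non-exceptional
+	// path: the error only carries the skip reason, so no stack trace is needed.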
+ ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackError("extension function should not be cached")) bf.tp.SetFlen(c.flen) sig := &extensionFuncSig{context.TODO(), bf, c.funcDef} return sig, nil } -func (c *extensionFuncClass) checkPrivileges(ctx sessionctx.Context) error { - fn := c.funcDef.RequireDynamicPrivileges +func checkPrivileges(ctx sessionctx.Context, fnDef *extension.FunctionDef) error { + fn := fnDef.RequireDynamicPrivileges if fn == nil { return nil } @@ -157,6 +162,10 @@ func (b *extensionFuncSig) Clone() builtinFunc { } func (b *extensionFuncSig) evalString(row chunk.Row) (string, bool, error) { + if err := checkPrivileges(b.ctx, &b.FunctionDef); err != nil { + return "", true, err + } + if b.EvalTp == types.ETString { return b.EvalStringFunc(b, row) } @@ -164,6 +173,10 @@ func (b *extensionFuncSig) evalString(row chunk.Row) (string, bool, error) { } func (b *extensionFuncSig) evalInt(row chunk.Row) (int64, bool, error) { + if err := checkPrivileges(b.ctx, &b.FunctionDef); err != nil { + return 0, true, err + } + if b.EvalTp == types.ETInt { return b.EvalIntFunc(b, row) } diff --git a/expression/integration_test/integration_test.go b/expression/integration_test/integration_test.go index 26435b5a522ce..07643567f93a0 100644 --- a/expression/integration_test/integration_test.go +++ b/expression/integration_test/integration_test.go @@ -8025,3 +8025,42 @@ func TestIssue50850(t *testing.T) { testkit.Rows("01", "31", "34", "5D", "65", "A5", "A6", "B1", "D5", "FF")) tk.MustExec("drop table if exists t3;") } + +func TestIssue51765(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec("create table t (id varbinary(16))") + tk.MustExec("create table t1(id char(16) charset utf8mb4 collate utf8mb4_general_ci)") + tk.MustExec("insert into t values ()") + tk.MustExec(`insert into t1 values ("Hello World")`) + + tk.MustQuery("select collation(ifnull(concat(NULL), '~'))").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select collation(ifnull(concat(NULL),ifnull(concat(NULL),'~')))").Check(testkit.Rows("utf8mb4_bin")) + tk.MustQuery("select collation(ifnull(concat(id),'~')) from t;").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(ifnull(concat(NULL),ifnull(concat(id),'~'))) from t;").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(ifnull(concat(id),ifnull(concat(id),'~'))) from t;").Check(testkit.Rows("binary")) + tk.MustQuery("select collation(ifnull(concat(NULL),id)) from t1;").Check(testkit.Rows("utf8mb4_general_ci")) + tk.MustQuery("select collation(ifnull(concat(NULL),ifnull(concat(NULL),id))) from t1;").Check(testkit.Rows("utf8mb4_general_ci")) +} + +func TestCastBinaryStringToJSON(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustQuery("select cast(binary 'aa' as json);").Check(testkit.Rows(`"base64:type254:YWE="`)) + + tk.MustExec("use test") + tk.MustExec("create table t (vb VARBINARY(10), b BINARY(10), vc VARCHAR(10), c CHAR(10));") + tk.MustExec("insert into t values ('1', '1', '1', '1');") + tk.MustQuery("select cast(vb as json), cast(b as json), cast(vc as json), cast(c as json) from t;").Check( + testkit.Rows(`"base64:type15:MQ==" "base64:type254:MQAAAAAAAAAAAA==" 1 1`)) + tk.MustQuery("select 1 from t where cast(vb as json) = '1';").Check(testkit.Rows()) + tk.MustQuery("select 1 from t where cast(b as json) = '1';").Check(testkit.Rows()) + tk.MustQuery("select 1 from t where cast(vc as json) = 
'1';").Check(testkit.Rows()) + tk.MustQuery("select 1 from t where cast(c as json) = '1';").Check(testkit.Rows()) + tk.MustQuery("select 1 from t where cast(BINARY vc as json) = '1';").Check(testkit.Rows()) + tk.MustQuery("select 1 from t where cast(BINARY c as json) = '1';").Check(testkit.Rows()) +} diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 131e018fba550..d4a4ca108baf2 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -334,6 +334,12 @@ func (sf *ScalarFunction) ConstItem(sc *stmtctx.StatementContext) bool { if _, ok := unFoldableFunctions[sf.FuncName.L]; ok { return false } + + if _, ok := sf.Function.(*extensionFuncSig); ok { + // we should return false for extension functions for safety, because it may have a side effect. + return false + } + for _, arg := range sf.GetArgs() { if !arg.ConstItem(sc) { return false diff --git a/extension/BUILD.bazel b/extension/BUILD.bazel index 578b3eca306a1..27ed31523e6f3 100644 --- a/extension/BUILD.bazel +++ b/extension/BUILD.bazel @@ -39,7 +39,7 @@ go_test( ], embed = [":extension"], flaky = True, - shard_count = 14, + shard_count = 15, deps = [ "//expression", "//parser/ast", @@ -55,6 +55,7 @@ go_test( "//testkit/testsetup", "//types", "//util/chunk", + "//util/mock", "//util/sem", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", diff --git a/extension/function_test.go b/extension/function_test.go index 06fa301543525..acd697192b9bf 100644 --- a/extension/function_test.go +++ b/extension/function_test.go @@ -18,6 +18,7 @@ import ( "fmt" "sort" "strings" + "sync/atomic" "testing" "github.com/pingcap/errors" @@ -28,6 +29,7 @@ import ( "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/sem" "github.com/stretchr/testify/require" ) @@ -318,6 +320,19 @@ func TestExtensionFuncPrivilege(t *testing.T) { return "ghi", false, nil }, }, + { + Name: "custom_eval_int_func", + EvalTp: types.ETInt, + RequireDynamicPrivileges: func(sem bool) []string { + if sem { + return []string{"RESTRICTED_CUSTOM_DYN_PRIV_2"} + } + return []string{"CUSTOM_DYN_PRIV_1"} + }, + EvalIntFunc: func(ctx extension.FunctionContext, row chunk.Row) (int64, bool, error) { + return 1, false, nil + }, + }, }), extension.WithCustomDynPrivs([]string{ "CUSTOM_DYN_PRIV_1", @@ -349,6 +364,7 @@ func TestExtensionFuncPrivilege(t *testing.T) { tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi")) + tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1")) // u1 in non-sem require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil, nil)) @@ -356,6 +372,11 @@ func TestExtensionFuncPrivilege(t *testing.T) { require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("select 
custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") + + // prepare should check privilege + require.EqualError(t, tk1.ExecToErr("prepare stmt1 from 'select custom_only_dyn_priv_func()'"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("prepare stmt2 from 'select custom_eval_int_func()'"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") // u2 in non-sem require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u2", Hostname: "localhost"}, nil, nil, nil)) @@ -363,6 +384,7 @@ func TestExtensionFuncPrivilege(t *testing.T) { tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi")) + tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1")) // u3 in non-sem require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u3", Hostname: "localhost"}, nil, nil, nil)) @@ -370,6 +392,7 @@ func TestExtensionFuncPrivilege(t *testing.T) { require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation") // u4 in non-sem require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u4", Hostname: "localhost"}, nil, nil, nil)) @@ -377,6 +400,7 @@ func TestExtensionFuncPrivilege(t *testing.T) { tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi")) + tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1")) sem.Enable() @@ -386,6 +410,7 @@ func TestExtensionFuncPrivilege(t *testing.T) { tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") // u1 in sem require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil, nil)) @@ -393,6 +418,7 @@ func TestExtensionFuncPrivilege(t *testing.T) { require.EqualError(t, tk1.ExecToErr("select 
custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the CUSTOM_DYN_PRIV_1 privilege(s) for this operation") require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") // u2 in sem require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u2", Hostname: "localhost"}, nil, nil, nil)) @@ -400,6 +426,7 @@ func TestExtensionFuncPrivilege(t *testing.T) { tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") // u3 in sem require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u3", Hostname: "localhost"}, nil, nil, nil)) @@ -407,6 +434,7 @@ func TestExtensionFuncPrivilege(t *testing.T) { require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the CUSTOM_DYN_PRIV_1 privilege(s) for this operation") tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi")) + tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1")) // u4 in sem require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u4", Hostname: "localhost"}, nil, nil, nil)) @@ -414,4 +442,82 @@ func TestExtensionFuncPrivilege(t *testing.T) { tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc")) tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def")) tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi")) + tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1")) + + // Test the privilege should also be checked when evaluating especially for when privilege is revoked. + // We enable `fixcontrol.Fix49736` to force enable plan cache to make sure `Expression.EvalXXX` will be invoked. 
+ tk1.MustExec("prepare s1 from 'select custom_both_dyn_priv_func()'") + tk1.MustExec("prepare s2 from 'select custom_eval_int_func()'") + tk1.MustQuery("execute s1").Check(testkit.Rows("ghi")) + tk1.MustQuery("execute s2").Check(testkit.Rows("1")) + tk.MustExec("REVOKE RESTRICTED_CUSTOM_DYN_PRIV_2 on *.* FROM u4@localhost") + require.EqualError(t, tk1.ExecToErr("execute s1"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") + require.EqualError(t, tk1.ExecToErr("execute s2"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation") +} + +func TestShouldNotOptimizeExtensionFunc(t *testing.T) { + defer func() { + extension.Reset() + sem.Disable() + }() + + extension.Reset() + var cnt atomic.Int64 + require.NoError(t, extension.Register("test", + extension.WithCustomFunctions([]*extension.FunctionDef{ + { + Name: "my_func1", + EvalTp: types.ETInt, + EvalIntFunc: func(ctx extension.FunctionContext, row chunk.Row) (int64, bool, error) { + val := cnt.Add(1) + return val, false, nil + }, + }, + { + Name: "my_func2", + EvalTp: types.ETString, + EvalStringFunc: func(ctx extension.FunctionContext, row chunk.Row) (string, bool, error) { + val := cnt.Add(1) + if val%2 == 0 { + return "abc", false, nil + } + return "def", false, nil + }, + }, + }), + )) + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t1(a int primary key)") + tk.MustExec("insert into t1 values(1000), (2000)") + + // Test extension function should not fold. + // if my_func1 is folded, the result will be "1000 1", "2000 1", + // because after fold the function will be called only once. + tk.MustQuery("select a, my_func1() from t1 order by a").Check(testkit.Rows("1000 1", "2000 2")) + require.Equal(t, int64(2), cnt.Load()) + + // Test extension function should not be seen as a constant, i.e., its `ConstantLevel()` should return `ConstNone`. + // my_func2 should be called twice to return different regexp string for the below query. + // If it is optimized by mistake, a wrong result "1000 0", "2000 0" will be produced. 
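+	// Reset the counter so the parity of my_func2's return values is deterministic
+	// for the regexp assertions below.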
+ cnt.Store(0) + tk.MustQuery("select a, 'abc' regexp my_func2() from t1 order by a").Check(testkit.Rows("1000 0", "2000 1")) + + // Test flags after building expression + for _, exprStr := range []string{ + "my_func1()", + "my_func2()", + } { + ctx := mock.NewContext() + ctx.GetSessionVars().StmtCtx.UseCache = true + exprs, err := expression.ParseSimpleExprsWithNames(ctx, exprStr, nil, nil) + require.NoError(t, err) + require.Equal(t, 1, len(exprs)) + scalar, ok := exprs[0].(*expression.ScalarFunction) + require.True(t, ok) + require.False(t, scalar.ConstItem(ctx.GetSessionVars().StmtCtx)) + require.False(t, ctx.GetSessionVars().StmtCtx.UseCache) + } } diff --git a/pkg/executor/test/executor/BUILD.bazel b/pkg/executor/test/executor/BUILD.bazel deleted file mode 100644 index fa992431ab81f..0000000000000 --- a/pkg/executor/test/executor/BUILD.bazel +++ /dev/null @@ -1,9 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -go_test( - name = "executor_test", - timeout = "short", - srcs = ["executor_test.go"], - flaky = True, - deps = ["//testkit"], -) diff --git a/pkg/executor/test/executor/executor_test.go b/pkg/executor/test/executor/executor_test.go deleted file mode 100644 index 96cdd815918e5..0000000000000 --- a/pkg/executor/test/executor/executor_test.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "testing" - - "github.com/pingcap/tidb/testkit" -) - -func TestIssues40463(t *testing.T) { - store := testkit.CreateMockStore(t) - tk := testkit.NewTestKit(t, store) - - tk.MustExec("use test;") - tk.MustExec("CREATE TABLE `4f380f26-9af6-4df8-959d-ad6296eff914` (`f7a9a4be-3728-449b-a5ea-df9b957eec67` enum('bkdv0','9rqy','lw','neud','ym','4nbv','9a7','bpkfo','xtfl','59','6vjj') NOT NULL DEFAULT 'neud', `43ca0135-1650-429b-8887-9eabcae2a234` set('8','5x47','xc','o31','lnz','gs5s','6yam','1','20ea','i','e') NOT NULL DEFAULT 'e', PRIMARY KEY (`f7a9a4be-3728-449b-a5ea-df9b957eec67`,`43ca0135-1650-429b-8887-9eabcae2a234`) /*T![clustered_index] CLUSTERED */) ENGINE=InnoDB DEFAULT CHARSET=ascii COLLATE=ascii_bin;") - tk.MustExec("INSERT INTO `4f380f26-9af6-4df8-959d-ad6296eff914` VALUES ('bkdv0','gs5s'),('lw','20ea'),('neud','8'),('ym','o31'),('4nbv','o31'),('xtfl','e');") - - tk.MustExec("CREATE TABLE `ba35a09f-76f4-40aa-9b48-13154a24bdd2` (`9b2a7138-14a3-4e8f-b29a-720392aad22c` set('zgn','if8yo','e','k7','bav','xj6','lkag','m5','as','ia','l3') DEFAULT 'zgn,if8yo,e,k7,ia,l3',`a60d6b5c-08bd-4a6d-b951-716162d004a5` set('6li6','05jlu','w','l','m','e9r','5q','d0ol','i6ajr','csf','d32') DEFAULT '6li6,05jlu,w,l,m,d0ol,i6ajr,csf,d32',`fb753d37-6252-4bd3-9bd1-0059640e7861` year(4) DEFAULT '2065', UNIQUE KEY `51816c39-27df-4bbe-a0e7-d6b6f54be2a2` (`fb753d37-6252-4bd3-9bd1-0059640e7861`), KEY `b0dfda0a-ffed-4c5b-9a72-4113bc1cbc8e` (`9b2a7138-14a3-4e8f-b29a-720392aad22c`,`fb753d37-6252-4bd3-9bd1-0059640e7861`)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin /*T! 
SHARD_ROW_ID_BITS=5 */;") - tk.MustExec("insert into `ba35a09f-76f4-40aa-9b48-13154a24bdd2` values ('if8yo', '6li6,05jlu,w,l,m,d0ol,i6ajr,csf,d32', 2065);") - - tk.MustExec("CREATE TABLE `07ccc74e-14c3-4685-bb41-c78a069b1a6d` (`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae` bigint(20) NOT NULL DEFAULT '-4604789462044748682',`30b19ecf-679f-4ca3-813f-d3c3b8f7da7e` date NOT NULL DEFAULT '5030-11-23',`1c52eaf2-1ebb-4486-9410-dfd00c7c835c` decimal(7,5) DEFAULT '-81.91307',`4b09dfdc-e688-41cb-9ffa-d03071a43077` float DEFAULT '1.7989023',PRIMARY KEY (`30b19ecf-679f-4ca3-813f-d3c3b8f7da7e`,`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae`) /*T![clustered_index] CLUSTERED */,KEY `ae7a7637-ca52-443b-8a3f-69694f730cc4` (`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae`),KEY `42640042-8a17-4145-9510-5bb419f83ed9` (`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae`),KEY `839f4f5a-83f3-449b-a7dd-c7d2974d351a` (`30b19ecf-679f-4ca3-813f-d3c3b8f7da7e`),KEY `c474cde1-6fe4-45df-9067-b4e479f84149` (`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae`),KEY `f834d0a9-709e-4ca8-925d-73f48322b70d` (`8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae`)) ENGINE=InnoDB DEFAULT CHARSET=gbk COLLATE=gbk_chinese_ci;") - tk.MustExec("set sql_mode=``;") - tk.MustExec("INSERT INTO `07ccc74e-14c3-4685-bb41-c78a069b1a6d` VALUES (616295989348159438,'0000-00-00',1.00000,1.7989023),(2215857492573998768,'1970-02-02',0.00000,1.7989023),(2215857492573998768,'1983-05-13',0.00000,1.7989023),(-2840083604831267906,'1984-01-30',1.00000,1.7989023),(599388718360890339,'1986-09-09',1.00000,1.7989023),(3506764933630033073,'1987-11-22',1.00000,1.7989023),(3506764933630033073,'2002-02-26',1.00000,1.7989023),(3506764933630033073,'2003-05-14',1.00000,1.7989023),(3506764933630033073,'2007-05-16',1.00000,1.7989023),(3506764933630033073,'2017-02-20',1.00000,1.7989023),(3506764933630033073,'2017-08-06',1.00000,1.7989023),(2215857492573998768,'2019-02-18',1.00000,1.7989023),(3506764933630033073,'2020-08-11',1.00000,1.7989023),(3506764933630033073,'2028-06-07',1.00000,1.7989023),(3506764933630033073,'2036-08-16',1.00000,1.7989023);") - - tk.MustQuery("select /*+ use_index_merge( `4f380f26-9af6-4df8-959d-ad6296eff914` ) */ /*+ stream_agg() */ approx_percentile( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` , 77 ) as r0 , `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` as r1 from `4f380f26-9af6-4df8-959d-ad6296eff914` where not( IsNull( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` ) ) and not( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` in ( select `8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae` from `07ccc74e-14c3-4685-bb41-c78a069b1a6d` where `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` in ( select `a60d6b5c-08bd-4a6d-b951-716162d004a5` from `ba35a09f-76f4-40aa-9b48-13154a24bdd2` where not( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` between 'bpkfo' and '59' ) and not( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` in ( select `fb753d37-6252-4bd3-9bd1-0059640e7861` from `ba35a09f-76f4-40aa-9b48-13154a24bdd2` where IsNull( `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` ) or not( `4f380f26-9af6-4df8-959d-ad6296eff914`.`43ca0135-1650-429b-8887-9eabcae2a234` in ( select `8a93bdc5-2214-4f96-b5a7-1ba4c0d396ae` from `07ccc74e-14c3-4685-bb41-c78a069b1a6d` where IsNull( `4f380f26-9af6-4df8-959d-ad6296eff914`.`43ca0135-1650-429b-8887-9eabcae2a234` ) and not( 
`4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67` between 'neud' and 'bpkfo' ) ) ) ) ) ) ) ) group by `4f380f26-9af6-4df8-959d-ad6296eff914`.`f7a9a4be-3728-449b-a5ea-df9b957eec67`;") -} diff --git a/planner/core/casetest/BUILD.bazel b/planner/core/casetest/BUILD.bazel index 2b63b59f5b2a6..88eea70a79d8b 100644 --- a/planner/core/casetest/BUILD.bazel +++ b/planner/core/casetest/BUILD.bazel @@ -23,6 +23,7 @@ go_test( "rule_result_reorder_test.go", "stats_test.go", "tiflash_selection_late_materialization_test.go", + "widow_with_exist_subquery_test.go", "window_push_down_test.go", ], data = glob(["testdata/**"]), diff --git a/planner/core/casetest/widow_with_exist_subquery_test.go b/planner/core/casetest/widow_with_exist_subquery_test.go new file mode 100644 index 0000000000000..6b5be44dce19c --- /dev/null +++ b/planner/core/casetest/widow_with_exist_subquery_test.go @@ -0,0 +1,77 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casetest + +import ( + "testing" + + "github.com/pingcap/tidb/testkit" +) + +func TestWindowWithCorrelatedSubQuery(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec("CREATE TABLE temperature_data (temperature double);") + tk.MustExec("CREATE TABLE humidity_data (humidity double);") + tk.MustExec("CREATE TABLE weather_report (report_id double, report_date varchar(100));") + + tk.MustExec("INSERT INTO temperature_data VALUES (1.0);") + tk.MustExec("INSERT INTO humidity_data VALUES (0.5);") + tk.MustExec("INSERT INTO weather_report VALUES (2.0, 'test');") + + result := tk.MustQuery(` + SELECT + EXISTS ( + SELECT + FIRST_VALUE(temp_data.temperature) OVER weather_window AS first_temperature, + MIN(report_data.report_id) OVER weather_window AS min_report_id + FROM + temperature_data AS temp_data + WINDOW weather_window AS ( + PARTITION BY EXISTS ( + SELECT + report_data.report_date AS report_date + FROM + humidity_data AS humidity_data + WHERE temp_data.temperature >= humidity_data.humidity + ) + ) + ) AS is_exist + FROM + weather_report AS report_data; + `) + + result.Check(testkit.Rows("1")) + + result = tk.MustQuery(` + SELECT + EXISTS ( + SELECT + FIRST_VALUE(temp_data.temperature) OVER weather_window AS first_temperature, + MIN(report_data.report_id) OVER weather_window AS min_report_id + FROM + temperature_data AS temp_data + WINDOW weather_window AS ( + PARTITION BY temp_data.temperature + ) + ) AS is_exist + FROM + weather_report AS report_data; + `) + + result.Check(testkit.Rows("1")) +} diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 21b388778ca08..9421511157b74 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -1194,7 +1194,7 @@ func getColsNDVLowerBoundFromHistColl(colUIDs []int64, histColl *statistics.Hist // 2. Try to get NDV from index stats. 
// Note that we don't need to specially handle prefix index here, because the NDV of a prefix index is // equal or less than the corresponding normal index, and that's safe here since we want a lower bound. - for idxID, idxCols := range histColl.Idx2ColumnIDs { + for idxID, idxCols := range histColl.Idx2ColUniqueIDs { if len(idxCols) != len(colUIDs) { continue } diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index ac2fefbe4115e..673d1463e5374 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -1863,7 +1863,7 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre return 0, err == nil, corr } idxID := int64(-1) - idxIDs, idxExists := ds.stats.HistColl.ColID2IdxIDs[colID] + idxIDs, idxExists := ds.stats.HistColl.ColUniqueID2IdxIDs[colID] if idxExists && len(idxIDs) > 0 { idxID = idxIDs[0] } diff --git a/planner/core/issuetest/planner_issue_test.go b/planner/core/issuetest/planner_issue_test.go index 60bf3a076a055..4851511315bb5 100644 --- a/planner/core/issuetest/planner_issue_test.go +++ b/planner/core/issuetest/planner_issue_test.go @@ -154,53 +154,53 @@ func TestIssue47881(t *testing.T) { tk.MustExec("create table t2(id int,name varchar(10),name1 varchar(10),name2 varchar(10),`date1` date);") tk.MustExec("insert into t2 values(1,'tt','ttt','tttt','2099-12-31'),(2,'dd','ddd','dddd','2099-12-31');") rs := tk.MustQuery(`WITH bzzs AS ( - SELECT - count(1) AS bzn - FROM + SELECT + count(1) AS bzn + FROM t c - ), + ), tmp1 AS ( - SELECT - t1.* - FROM - t1 - LEFT JOIN bzzs ON 1 = 1 - WHERE - name IN ('tt') + SELECT + t1.* + FROM + t1 + LEFT JOIN bzzs ON 1 = 1 + WHERE + name IN ('tt') AND bzn <> 1 - ), + ), tmp2 AS ( - SELECT - tmp1.*, - date('2099-12-31') AS endate - FROM + SELECT + tmp1.*, + date('2099-12-31') AS endate + FROM tmp1 - ), + ), tmp3 AS ( - SELECT - * - FROM - tmp2 - WHERE - endate > CURRENT_DATE - UNION ALL - SELECT - '1' AS id, - 'ss' AS name, - 'sss' AS name1, - 'ssss' AS name2, - date('2099-12-31') AS endate - FROM - bzzs t1 - WHERE + SELECT + * + FROM + tmp2 + WHERE + endate > CURRENT_DATE + UNION ALL + SELECT + '1' AS id, + 'ss' AS name, + 'sss' AS name1, + 'ssss' AS name2, + date('2099-12-31') AS endate + FROM + bzzs t1 + WHERE bzn = 1 - ) - SELECT - c2.id, - c3.id - FROM - t2 db - LEFT JOIN tmp3 c2 ON c2.id = '1' + ) + SELECT + c2.id, + c3.id + FROM + t2 db + LEFT JOIN tmp3 c2 ON c2.id = '1' LEFT JOIN tmp3 c3 ON c3.id = '1';`) rs.Check(testkit.Rows("1 1", "1 1")) } @@ -230,3 +230,35 @@ func TestIssue51670(t *testing.T) { tk.MustQuery("select b.b from A a left join (B b left join C c on b.b = c.b) on b.b = a.b where a.a in (2, 3);").Sort().Check(testkit.Rows("1", "2")) tk.MustQuery("select b.b from A a left join (B b left join C c on b.b = c.b) on b.b = a.b where a.a in (2, 3, null);").Sort().Check(testkit.Rows("1", "2")) } + +func TestIssue50614(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists tt") + tk.MustExec("create table tt(a bigint, b bigint, c bigint, d bigint, e bigint, primary key(c,d));") + tk.MustQuery("explain format = brief " + + "update tt, (select 1 as c1 ,2 as c2 ,3 as c3, 4 as c4 union all select 2,3,4,5 union all select 3,4,5,6) tmp " + + "set tt.a=tmp.c1, tt.b=tmp.c2 " + + "where tt.c=tmp.c3 and tt.d=tmp.c4 and (tt.c,tt.d) in ((11,111),(22,222),(33,333),(44,444));").Check( + testkit.Rows( + "Update N/A root N/A", + "└─Projection 0.00 root test.tt.a, test.tt.b, 
test.tt.c, test.tt.d, test.tt.e, Column#18, Column#19, Column#20, Column#21", + " └─Projection 0.00 root test.tt.a, test.tt.b, test.tt.c, test.tt.d, test.tt.e, Column#18, Column#19, Column#20, Column#21", + " └─IndexJoin 0.00 root inner join, inner:TableReader, outer key:Column#20, Column#21, inner key:test.tt.c, test.tt.d, equal cond:eq(Column#20, test.tt.c), eq(Column#21, test.tt.d), other cond:or(or(and(eq(Column#20, 11), eq(test.tt.d, 111)), and(eq(Column#20, 22), eq(test.tt.d, 222))), or(and(eq(Column#20, 33), eq(test.tt.d, 333)), and(eq(Column#20, 44), eq(test.tt.d, 444)))), or(or(and(eq(test.tt.c, 11), eq(Column#21, 111)), and(eq(test.tt.c, 22), eq(Column#21, 222))), or(and(eq(test.tt.c, 33), eq(Column#21, 333)), and(eq(test.tt.c, 44), eq(Column#21, 444))))", + " ├─Union(Build) 0.00 root ", + " │ ├─Projection 0.00 root Column#6, Column#7, Column#8, Column#9", + " │ │ └─Projection 0.00 root 1->Column#6, 2->Column#7, 3->Column#8, 4->Column#9", + " │ │ └─TableDual 0.00 root rows:0", + " │ ├─Projection 0.00 root Column#10, Column#11, Column#12, Column#13", + " │ │ └─Projection 0.00 root 2->Column#10, 3->Column#11, 4->Column#12, 5->Column#13", + " │ │ └─TableDual 0.00 root rows:0", + " │ └─Projection 0.00 root Column#14, Column#15, Column#16, Column#17", + " │ └─Projection 0.00 root 3->Column#14, 4->Column#15, 5->Column#16, 6->Column#17", + " │ └─TableDual 0.00 root rows:0", + " └─TableReader(Probe) 0.00 root data:Selection", + " └─Selection 0.00 cop[tikv] or(or(and(eq(test.tt.c, 11), eq(test.tt.d, 111)), and(eq(test.tt.c, 22), eq(test.tt.d, 222))), or(and(eq(test.tt.c, 33), eq(test.tt.d, 333)), and(eq(test.tt.c, 44), eq(test.tt.d, 444)))), or(or(eq(test.tt.c, 11), eq(test.tt.c, 22)), or(eq(test.tt.c, 33), eq(test.tt.c, 44))), or(or(eq(test.tt.d, 111), eq(test.tt.d, 222)), or(eq(test.tt.d, 333), eq(test.tt.d, 444)))", + " └─TableRangeScan 0.00 cop[tikv] table:tt range: decided by [eq(test.tt.c, Column#20) eq(test.tt.d, Column#21)], keep order:false, stats:pseudo", + ), + ) +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 526327586b363..dd6fec22e032d 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -6452,6 +6452,14 @@ func (b *PlanBuilder) buildByItemsForWindow( } if col, ok := it.(*expression.Column); ok { retItems = append(retItems, property.SortItem{Col: col, Desc: item.Desc}) + // We need to attempt to add this column because a subquery may be created during the expression rewrite process. + // Therefore, we need to ensure that the column from the newly created query plan is added. + // If the column is already in the schema, we don't need to add it again. + if !proj.schema.Contains(col) { + proj.Exprs = append(proj.Exprs, col) + proj.names = append(proj.names, types.EmptyName) + proj.schema.Append(col) + } continue } proj.Exprs = append(proj.Exprs, it) diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 9c70f79ed104f..6a6ad6d5980cb 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -1142,9 +1142,13 @@ type LogicalMaxOneRow struct { } // LogicalTableDual represents a dual table plan. +// Note that sometimes we don't set schema for LogicalTableDual (most notably in buildTableDual()), which means +// outputting 0/1 row with zero column. This semantic may be different from your expectation sometimes but should not +// cause any actual problems now. 
type LogicalTableDual struct { logicalSchemaProducer + // RowCount could only be 0 or 1. RowCount int } @@ -1550,8 +1554,8 @@ func (ds *DataSource) fillIndexPath(path *util.AccessPath, conds []expression.Ex path.IdxCols = append(path.IdxCols, handleCol) path.IdxColLens = append(path.IdxColLens, types.UnspecifiedLength) // Also updates the map that maps the index id to its prefix column ids. - if len(ds.tableStats.HistColl.Idx2ColumnIDs[path.Index.ID]) == len(path.Index.Columns) { - ds.tableStats.HistColl.Idx2ColumnIDs[path.Index.ID] = append(ds.tableStats.HistColl.Idx2ColumnIDs[path.Index.ID], handleCol.UniqueID) + if len(ds.tableStats.HistColl.Idx2ColUniqueIDs[path.Index.ID]) == len(path.Index.Columns) { + ds.tableStats.HistColl.Idx2ColUniqueIDs[path.Index.ID] = append(ds.tableStats.HistColl.Idx2ColUniqueIDs[path.Index.ID], handleCol.UniqueID) } } } diff --git a/planner/core/main_test.go b/planner/core/main_test.go index 28390b1dd4227..30810fe168295 100644 --- a/planner/core/main_test.go +++ b/planner/core/main_test.go @@ -35,6 +35,7 @@ func TestMain(m *testing.M) { testDataMap.LoadTestSuiteData("testdata", "plan_suite_unexported") testDataMap.LoadTestSuiteData("testdata", "index_merge_suite") testDataMap.LoadTestSuiteData("testdata", "join_reorder_suite") + testDataMap.LoadTestSuiteData("testdata", "plan_stats_suite") indexMergeSuiteData = testDataMap["index_merge_suite"] planSuiteUnexportedData = testDataMap["plan_suite_unexported"] @@ -62,3 +63,7 @@ func GetIndexMergeSuiteData() testdata.TestData { func GetJoinReorderData() testdata.TestData { return testDataMap["join_reorder_suite"] } + +func GetPlanStatsData() testdata.TestData { + return testDataMap["plan_stats_suite"] +} diff --git a/planner/core/plan_stats.go b/planner/core/plan_stats.go index a4f579459f249..22b671dfb0d24 100644 --- a/planner/core/plan_stats.go +++ b/planner/core/plan_stats.go @@ -36,7 +36,7 @@ func (collectPredicateColumnsPoint) optimize(_ context.Context, plan LogicalPlan return plan, nil } predicateNeeded := variable.EnableColumnTracking.Load() - syncWait := plan.SCtx().GetSessionVars().StatsLoadSyncWait * time.Millisecond.Nanoseconds() + syncWait := plan.SCtx().GetSessionVars().StatsLoadSyncWait.Load() * time.Millisecond.Nanoseconds() histNeeded := syncWait > 0 predicateColumns, histNeededColumns := CollectColumnStatsUsage(plan, predicateNeeded, histNeeded) if len(predicateColumns) > 0 { diff --git a/planner/core/plan_stats_test.go b/planner/core/plan_stats_test.go index fa0072973bfc4..82a82d2072265 100644 --- a/planner/core/plan_stats_test.go +++ b/planner/core/plan_stats_test.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/statistics/handle" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/testkit/testdata" "github.com/stretchr/testify/require" ) @@ -326,3 +327,48 @@ func TestPlanStatsStatusRecord(t *testing.T) { } } } + +func TestPartialStatsInExplain(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(a int, b int, c int, primary key(a), key idx(b))") + tk.MustExec("insert into t values (1,1,1),(2,2,2),(3,3,3)") + tk.MustExec("create table t2(a int, primary key(a))") + tk.MustExec("insert into t2 values (1),(2),(3)") + tk.MustExec( + "create table tp(a int, b int, c int, index ic(c)) partition by range(a)" + + "(partition p0 values less than (10)," + + "partition p1 values less than (20)," + + "partition p2 values less than maxvalue)", + ) + 
tk.MustExec("insert into tp values (1,1,1),(2,2,2),(13,13,13),(14,14,14),(25,25,25),(36,36,36)") + + oriLease := dom.StatsHandle().Lease() + dom.StatsHandle().SetLease(1) + defer func() { + dom.StatsHandle().SetLease(oriLease) + }() + tk.MustExec("analyze table t") + tk.MustExec("analyze table t2") + tk.MustExec("analyze table tp") + require.NoError(t, dom.StatsHandle().Update(dom.InfoSchema())) + tk.MustQuery("explain select * from tp where a = 1") + tk.MustExec("set @@tidb_stats_load_sync_wait = 0") + var ( + input []string + output []struct { + Query string + Result []string + } + ) + testData := plannercore.GetPlanStatsData() + testData.LoadTestCases(t, &input, &output) + for i, sql := range input { + testdata.OnRecord(func() { + output[i].Query = input[i] + output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(sql).Rows()) + }) + tk.MustQuery(sql).Check(testkit.Rows(output[i].Result...)) + } +} diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 3a551f19acb64..ae1c4006b4c66 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -139,7 +139,7 @@ func Preprocess(ctx context.Context, sctx sessionctx.Context, node ast.Node, pre return errors.Trace(v.err) } -type preprocessorFlag uint8 +type preprocessorFlag uint64 const ( // inPrepare is set when visiting in prepare statement. @@ -157,6 +157,8 @@ const ( inSequenceFunction // initTxnContextProvider is set when we should init txn context in preprocess initTxnContextProvider + // inAnalyze is set when visiting an analyze statement. + inAnalyze ) // Make linter happy. @@ -393,6 +395,8 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { p.sctx.GetSessionVars().StmtCtx.IsStaleness = true p.IsStaleness = true } + case *ast.AnalyzeTableStmt: + p.flag |= inAnalyze default: p.flag &= ^parentIsJoin } @@ -1568,10 +1572,12 @@ func (p *preprocessor) handleTableName(tn *ast.TableName) { if tn.Schema.String() != "" { currentDB = tn.Schema.L } - table, err = tryLockMDLAndUpdateSchemaIfNecessary(p.sctx, model.NewCIStr(currentDB), table, p.ensureInfoSchema()) - if err != nil { - p.err = err - return + if !p.skipLockMDL() { + table, err = tryLockMDLAndUpdateSchemaIfNecessary(p.sctx, model.NewCIStr(currentDB), table, p.ensureInfoSchema()) + if err != nil { + p.err = err + return + } } tableInfo := table.Meta() @@ -1907,3 +1913,11 @@ func tryLockMDLAndUpdateSchemaIfNecessary(sctx sessionctx.Context, dbName model. } return tbl, nil } + +// skipLockMDL returns true if the preprocessor should skip the lock of MDL. +func (p *preprocessor) skipLockMDL() bool { + // skip lock mdl for IMPORT INTO statement, + // because it's a batch process and will do both DML and DDL. + // skip lock mdl for ANALYZE statement. 
+ return p.flag&inAnalyze > 0 +} diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index 2db96bd69a973..114d511d83ade 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -500,11 +500,6 @@ func (p *LogicalProjection) PredicatePushDown(predicates []expression.Expression return predicates, child } } - if len(p.children) == 1 { - if _, isDual := p.children[0].(*LogicalTableDual); isDual { - return predicates, p - } - } for _, cond := range predicates { substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(cond, p.Schema(), p.Exprs, true) if substituted && !hasFailed && !expression.HasGetSetVarFunc(newFilter) { diff --git a/planner/core/stats.go b/planner/core/stats.go index 3baad8e473323..3b83d5d9aefa2 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -210,8 +210,8 @@ func (ds *DataSource) getGroupNDVs(colGroups [][]*expression.Column) []property. tbl := ds.tableStats.HistColl ndvs := make([]property.GroupNDV, 0, len(colGroups)) for idxID, idx := range tbl.Indices { - colsLen := len(tbl.Idx2ColumnIDs[idxID]) - // tbl.Idx2ColumnIDs may only contain the prefix of index columns. + colsLen := len(tbl.Idx2ColUniqueIDs[idxID]) + // tbl.Idx2ColUniqueIDs may only contain the prefix of index columns. // But it may exceeds the total index since the index would contain the handle column if it's not a unique index. // We append the handle at fillIndexPath. if colsLen < len(idx.Info.Columns) { @@ -220,7 +220,7 @@ func (ds *DataSource) getGroupNDVs(colGroups [][]*expression.Column) []property. colsLen-- } idxCols := make([]int64, colsLen) - copy(idxCols, tbl.Idx2ColumnIDs[idxID]) + copy(idxCols, tbl.Idx2ColUniqueIDs[idxID]) slices.Sort(idxCols) for _, g := range colGroups { // We only want those exact matches. 
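The "exact matches" requirement above reduces to an order-insensitive comparison of two sets of column unique IDs: an index's NDV is only borrowed for a column group that covers exactly the index's (possibly handle-extended) columns. A small illustrative sketch of that comparison, assuming plain `int64` IDs:

```go
package main

import (
	"fmt"
	"sort"
)

// sameColumnSet reports whether a column group covers exactly the same columns
// as an index, by comparing the two unique-ID sets order-insensitively. This is
// the shape of the "exact match" check used when borrowing an index's NDV for a
// column group.
func sameColumnSet(groupIDs, idxIDs []int64) bool {
	if len(groupIDs) != len(idxIDs) {
		return false
	}
	a := append([]int64(nil), groupIDs...)
	b := append([]int64(nil), idxIDs...)
	sort.Slice(a, func(i, j int) bool { return a[i] < a[j] })
	sort.Slice(b, func(i, j int) bool { return b[i] < b[j] })
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(sameColumnSet([]int64{3, 1}, []int64{1, 3})) // true: same set, different order
	fmt.Println(sameColumnSet([]int64{1, 2}, []int64{1, 3})) // false: not an exact match
}
```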
diff --git a/planner/core/testdata/plan_stats_suite_in.json b/planner/core/testdata/plan_stats_suite_in.json new file mode 100644 index 0000000000000..aed7359bfac87 --- /dev/null +++ b/planner/core/testdata/plan_stats_suite_in.json @@ -0,0 +1,10 @@ +[ + { + "name": "TestPartialStatsInExplain", + "cases": [ + "explain format = brief select * from tp where b = 10", + "explain format = brief select * from t join tp where tp.a = 10 and t.b = tp.c", + "explain format = brief select * from t join tp partition (p0) join t2 where t.a < 10 and t.b = tp.c and t2.a > 10 and t2.a = tp.c" + ] + } +] diff --git a/planner/core/testdata/plan_stats_suite_out.json b/planner/core/testdata/plan_stats_suite_out.json new file mode 100644 index 0000000000000..617eab8ebc9a7 --- /dev/null +++ b/planner/core/testdata/plan_stats_suite_out.json @@ -0,0 +1,43 @@ +[ + { + "Name": "TestPartialStatsInExplain", + "Cases": [ + { + "Query": "explain format = brief select * from tp where b = 10", + "Result": [ + "TableReader 0.01 root partition:all data:Selection", + "└─Selection 0.01 cop[tikv] eq(test.tp.b, 10)", + " └─TableFullScan 6.00 cop[tikv] table:tp keep order:false, stats:partial[b:allEvicted]" + ] + }, + { + "Query": "explain format = brief select * from t join tp where tp.a = 10 and t.b = tp.c", + "Result": [ + "Projection 0.00 root test.t.a, test.t.b, test.t.c, test.tp.a, test.tp.b, test.tp.c", + "└─HashJoin 0.00 root inner join, equal:[eq(test.tp.c, test.t.b)]", + " ├─TableReader(Build) 0.00 root partition:p1 data:Selection", + " │ └─Selection 0.00 cop[tikv] eq(test.tp.a, 10), not(isnull(test.tp.c))", + " │ └─TableFullScan 6.00 cop[tikv] table:tp keep order:false, stats:partial[c:allEvicted]", + " └─TableReader(Probe) 3.00 root data:Selection", + " └─Selection 3.00 cop[tikv] not(isnull(test.t.b))", + " └─TableFullScan 3.00 cop[tikv] table:t keep order:false, stats:partial[b:allEvicted]" + ] + }, + { + "Query": "explain format = brief select * from t join tp partition (p0) join t2 where t.a < 10 and t.b = tp.c and t2.a > 10 and t2.a = tp.c", + "Result": [ + "HashJoin 0.00 root inner join, equal:[eq(test.tp.c, test.t2.a)]", + "├─TableReader(Build) 0.00 root data:TableRangeScan", + "│ └─TableRangeScan 0.00 cop[tikv] table:t2 range:(10,+inf], keep order:false", + "└─HashJoin(Probe) 0.00 root inner join, equal:[eq(test.t.b, test.tp.c)]", + " ├─TableReader(Build) 0.00 root data:Selection", + " │ └─Selection 0.00 cop[tikv] gt(test.t.b, 10), not(isnull(test.t.b))", + " │ └─TableRangeScan 3.00 cop[tikv] table:t range:[-inf,10), keep order:false, stats:partial[b:allEvicted]", + " └─TableReader(Probe) 4.00 root partition:p0 data:Selection", + " └─Selection 4.00 cop[tikv] gt(test.tp.c, 10), not(isnull(test.tp.c))", + " └─TableFullScan 6.00 cop[tikv] table:tp keep order:false, stats:partial[c:allEvicted]" + ] + } + ] + } +] diff --git a/privilege/privileges/ldap/BUILD.bazel b/privilege/privileges/ldap/BUILD.bazel index a807291caf073..fb822094db1a3 100644 --- a/privilege/privileges/ldap/BUILD.bazel +++ b/privilege/privileges/ldap/BUILD.bazel @@ -12,9 +12,11 @@ go_library( visibility = ["//visibility:public"], deps = [ "//privilege/conn", + "//util/logutil", "@com_github_go_ldap_ldap_v3//:ldap", "@com_github_ngaut_pools//:pools", "@com_github_pingcap_errors//:errors", + "@org_uber_go_zap//:zap", ], ) @@ -23,6 +25,7 @@ go_test( timeout = "short", srcs = ["ldap_common_test.go"], embed = [":ldap"], + embedsrcs = ["test/ca.crt"], flaky = True, deps = ["@com_github_stretchr_testify//require"], ) diff --git 
a/privilege/privileges/ldap/ldap_common.go b/privilege/privileges/ldap/ldap_common.go
index e28f15d4ef447..9d450680f9522 100644
--- a/privilege/privileges/ldap/ldap_common.go
+++ b/privilege/privileges/ldap/ldap_common.go
@@ -22,12 +22,24 @@ import (
 	"os"
 	"strconv"
 	"sync"
+	"time"
 
 	"github.com/go-ldap/ldap/v3"
 	"github.com/ngaut/pools"
 	"github.com/pingcap/errors"
+	"github.com/pingcap/tidb/util/logutil"
+	"go.uber.org/zap"
 )
 
+// ldapTimeout is set to 10s. It applies to both the TCP/TLS dialing timeout and the LDAP request timeout. With TLS enabled,
+// the user may find that a failure takes 2*ldapTimeout, because TiDB tries to connect through both `StartTLS` (upgrading a
+// normal TCP connection) and direct `TLS`.
+var ldapTimeout = 10 * time.Second
+
+// skipTLSForTest is used to skip trying to connect with TLS directly in tests. If it's set to true, the connection will only
+// try to use `StartTLS`.
+var skipTLSForTest = false
+
 // ldapAuthImpl gives the internal utilities of authentication with LDAP.
 // The getter and setter methods will lock the mutex inside, while all other methods don't, so all other method call
 // should be protected by `impl.Lock()`.
@@ -120,10 +132,13 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) {
 	// It's fine to load these two TLS configurations one-by-one (but not guarded by a single lock), because there isn't
 	// a way to set two variables atomically.
 	if impl.enableTLS {
-		ldapConnection, err := ldap.Dial("tcp", address)
+		ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer(&net.Dialer{
+			Timeout: ldapTimeout,
+		}))
 		if err != nil {
 			return nil, errors.Wrap(err, "create ldap connection")
 		}
+		ldapConnection.SetTimeout(ldapTimeout)
 
 		err = ldapConnection.StartTLS(&tls.Config{
 			RootCAs: impl.caPool,
@@ -134,15 +149,19 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) {
 		}
 		return ldapConnection, nil
 	}
-	ldapConnection, err := ldap.Dial("tcp", address)
+	ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer(&net.Dialer{
+		Timeout: ldapTimeout,
+	}))
 	if err != nil {
 		return nil, errors.Wrap(err, "create ldap connection")
 	}
+	ldapConnection.SetTimeout(ldapTimeout)
 
 	return ldapConnection, nil
 }
 
 const getConnectionMaxRetry = 10
+const getConnectionRetryInterval = 500 * time.Millisecond
 
 func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) {
 	retryCount := 0
@@ -163,6 +182,9 @@ func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) {
 			Password: impl.bindRootPWD,
 		})
 		if err != nil {
+			logutil.BgLogger().Warn("fail to use LDAP connection bind to anonymous user. Retrying", zap.Error(err),
+				zap.Duration("backoff", getConnectionRetryInterval))
+
 			// fail to bind to anonymous user, just release this connection and try to get a new one
 			impl.ldapConnectionPool.Put(nil)
 
@@ -170,6 +192,9 @@ func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) {
 			if retryCount >= getConnectionMaxRetry {
 				return nil, errors.Wrap(err, "fail to bind to anonymous user")
 			}
+			// Be careful that it's still holding the lock of the system variables, so it's not good to sleep here.
+			// TODO: refactor the `RWLock` to avoid the problem of holding the lock.
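+			// Until that refactor lands, each retry keeps holding the lock for another
+			// getConnectionRetryInterval while sleeping.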
+			time.Sleep(getConnectionRetryInterval)
 			continue
 		}
 
@@ -182,12 +207,12 @@ func (impl *ldapAuthImpl) putConnection(conn *ldap.Conn) {
 }
 
 func (impl *ldapAuthImpl) initializePool() {
-	if impl.ldapConnectionPool != nil {
-		impl.ldapConnectionPool.Close()
-	}
-
-	// skip initialization when the variables are not correct
+	// skip re-initialization when the variables are not correct
 	if impl.initCapacity > 0 && impl.maxCapacity >= impl.initCapacity {
+		if impl.ldapConnectionPool != nil {
+			impl.ldapConnectionPool.Close()
+		}
+
 		impl.ldapConnectionPool = pools.NewResourcePool(impl.connectionFactory, impl.initCapacity, impl.maxCapacity, 0)
 	}
 }
@@ -232,6 +257,7 @@ func (impl *ldapAuthImpl) SetLDAPServerHost(ldapServerHost string) {
 
 	if ldapServerHost != impl.ldapServerHost {
 		impl.ldapServerHost = ldapServerHost
+		impl.initializePool()
 	}
 }
 
@@ -242,6 +268,7 @@ func (impl *ldapAuthImpl) SetLDAPServerPort(ldapServerPort int) {
 
 	if ldapServerPort != impl.ldapServerPort {
 		impl.ldapServerPort = ldapServerPort
+		impl.initializePool()
 	}
 }
 
@@ -252,6 +279,7 @@ func (impl *ldapAuthImpl) SetEnableTLS(enableTLS bool) {
 
 	if enableTLS != impl.enableTLS {
 		impl.enableTLS = enableTLS
+		impl.initializePool()
 	}
 }
diff --git a/privilege/privileges/ldap/ldap_common_test.go b/privilege/privileges/ldap/ldap_common_test.go
index 7794b5ab5f3d4..d8e8a870015c3 100644
--- a/privilege/privileges/ldap/ldap_common_test.go
+++ b/privilege/privileges/ldap/ldap_common_test.go
@@ -15,11 +15,21 @@
 package ldap
 
 import (
+	"crypto/x509"
+	_ "embed"
+	"fmt"
+	"math/rand"
+	"net"
+	"sync"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 )
 
+//go:embed test/ca.crt
+var tlsCAStr []byte
+
 func TestCanonicalizeDN(t *testing.T) {
 	impl := &ldapAuthImpl{
 		searchAttr: "cn",
@@ -27,3 +37,64 @@ func TestCanonicalizeDN(t *testing.T) {
 	require.Equal(t, impl.canonicalizeDN("yka", "cn=y,dc=ping,dc=cap"), "cn=y,dc=ping,dc=cap")
 	require.Equal(t, impl.canonicalizeDN("yka", "+dc=ping,dc=cap"), "cn=yka,dc=ping,dc=cap")
 }
+
+func TestLDAPStartTLSTimeout(t *testing.T) {
+	originalTimeout := ldapTimeout
+	ldapTimeout = time.Second * 2
+	skipTLSForTest = true
+	defer func() {
+		ldapTimeout = originalTimeout
+		skipTLSForTest = false
+	}()
+
+	var ln net.Listener
+	startListen := make(chan struct{})
+	afterTimeout := make(chan struct{})
+	defer close(afterTimeout)
+
+	// this test only checks whether LDAP with TLS enabled will fall back from StartTLS
+	randomTLSServicePort := rand.Int()%10000 + 10000
+	serverWg := &sync.WaitGroup{}
+	serverWg.Add(1)
+	go func() {
+		var err error
+		defer close(startListen)
+		defer serverWg.Done()
+
+		ln, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", randomTLSServicePort))
+		require.NoError(t, err)
+		startListen <- struct{}{}
+
+		conn, err := ln.Accept()
+		require.NoError(t, err)
+
+		<-afterTimeout
+		require.NoError(t, conn.Close())
+
+		// close the server
+		require.NoError(t, ln.Close())
+	}()
+
+	<-startListen
+	defer func() {
+		serverWg.Wait()
+	}()
+
+	impl := &ldapAuthImpl{}
+	impl.SetEnableTLS(true)
+	impl.SetLDAPServerHost("localhost")
+	impl.SetLDAPServerPort(randomTLSServicePort)
+
+	impl.caPool = x509.NewCertPool()
+	require.True(t, impl.caPool.AppendCertsFromPEM(tlsCAStr))
+	impl.SetInitCapacity(1)
+	impl.SetMaxCapacity(1)
+
+	now := time.Now()
+	_, err := impl.connectionFactory()
+	afterTimeout <- struct{}{}
+	dur := time.Since(now)
+	require.Greater(t, dur, 2*time.Second)
+	require.Less(t, dur, 3*time.Second)
+	require.ErrorContains(t, err, "connection timed out")
+}
diff --git 
a/privilege/privileges/ldap/test/ca.crt b/privilege/privileges/ldap/test/ca.crt new file mode 100644 index 0000000000000..cbef4f3fb2bc4 --- /dev/null +++ b/privilege/privileges/ldap/test/ca.crt @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFZTCCA02gAwIBAgIUZ2hQOFVvjuAHrC8Tv+5dnwPGvF0wDQYJKoZIhvcNAQEL +BQAwQjELMAkGA1UEBhMCWFgxFTATBgNVBAcMDERlZmF1bHQgQ2l0eTEcMBoGA1UE +CgwTRGVmYXVsdCBDb21wYW55IEx0ZDAeFw0yMzA0MjQwNjM1MTRaFw0yODA0MjMw +NjM1MTRaMEIxCzAJBgNVBAYTAlhYMRUwEwYDVQQHDAxEZWZhdWx0IENpdHkxHDAa +BgNVBAoME0RlZmF1bHQgQ29tcGFueSBMdGQwggIiMA0GCSqGSIb3DQEBAQUAA4IC +DwAwggIKAoICAQDFvQt3xupYFQxZsyQPr2byhR9ZHoAWBxxqNWxbvpqy7RzHeccJ +Jg36dO1BNIBY8NjIy/JHV7eLDVGCh9FTGozn8dODQMOwDXTYqxFOiBHb2/M9wxVX +ILafa1GlsOnUFxEws9T0XH7ZBqMLC/KlXdJ5xQD1C36K1eWHvphjD0AFhgUnqQ4N +O3NT3tJjzcY7oXBoKpgSgQs7qeTdJLTKJE7pY02C/hJI2WvJDdIiEhZTi0UWqE06 +5aXp8Heag/H4VlZzRA+RzQuDXqgXC3Bt7mJwtnoym0HgyTvoKBKO/vzfAW1yQhGo +6ikfSZkvIy3kyPAxD1FSWeSA0Xo8soGNDUsZjV6dQRtcnlWLPFA+7VfwivCPNiFh +91csXhHNEkYPNq4yCe6ZsycydZ+GNdNygzIrMjQ+DjNnHXXmfdeiiLLJbyxYzwuu +GaAT9eD98vXeJFhuWSbKyj4oXcMKnj3bTAQnudMCHIV8cMDe7Zuq1d4/gjXvk95Z +s/OnxqRYYNTXECkdLrevAPfGI2Qg9d7IrhnAh6KqCDDiFkhDkS5TMbWeHA88ZPKZ +6RvLYZmA+j3tjtKPpta9yPiTglExjBUDHIe+37K84G4p0C4bEo5RxEop5hHuX4EW +QvwNb0254i+RCsdyt+tFHiAVzo3/mTg9EMkWlTzoy0381HytFNcLDGb37QIDAQAB +o1MwUTAdBgNVHQ4EFgQUHs+YTH+x0YRNja11v18CF4iu5XowHwYDVR0jBBgwFoAU +Hs+YTH+x0YRNja11v18CF4iu5XowDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0B +AQsFAAOCAgEAqf2ukpuqXC7u0kTqWdrw+3x2R+mwoLSsy7QbpsjdPbJwVoEysC02 +WHiZcMj+w75wZQ2SS5Md1+1K4HiIWC3n+eRDodz+Di0Hg9lxtoMFuOIBpnYUsDuA +Fooo/B7HadZkw9AbWFxPK5EGLyfCRuZF50981svSX5rnYqgCLLs0zGxr7Uswhzvh +3fQMDd0OGLST0M4FW/pQkRYIWnQ4zn/n+wJaHBeaKXHJ7QfgNCtVXOLTXdzpreIL +RvvydcOdVoPnjMgCs1jyhB6g226+/nOuQqic53pxnUTUTvHFJQ5B6/JlzMHeJS1m +ycvSxmF+3RqhjePiwRAT/ui9FBXkhXSG3wp0n0w83rpq7Ne0uwPH/KE9hqFkiI5x +PzobjKy6ahzoSbZrw/a4gDXfZe3fYGtm1EdyDYTh1HFCi7hkdoxY9iCIL1Gr+mpB +YruELQZ+RpvZQ7V8JN7bPtzWfPybPkOSozP1EoLXhUAnXl4/DinoBZvum1MpvPWY +sKF9qQfTP6cAqOuIL1LcVhJ7yHAjR+BK7tvhA2h4sIqxEjhlDmJjH0XMr9JpYQcb +yBzNgkS0YycMPJru0zb2p7vodql5rxSWArQW13Pyqza6N803Qk3vP0/SCfYXeR/i +hv/InNBQBwfHo79FBEv/T7UB8yS7CIu75f562jp23DFKUQD+doybmDg= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 2be92ac60ac07..dc09defed1662 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1279,7 +1279,7 @@ type SessionVars struct { ReadConsistency ReadConsistencyLevel // StatsLoadSyncWait indicates how long to wait for stats load before timeout. 
- StatsLoadSyncWait int64 + StatsLoadSyncWait atomic.Int64 // SysdateIsNow indicates whether Sysdate is an alias of Now function SysdateIsNow bool @@ -1944,7 +1944,6 @@ func NewSessionVars(hctx HookContext) *SessionVars { TMPTableSize: DefTiDBTmpTableMaxSize, MPPStoreFailTTL: DefTiDBMPPStoreFailTTL, Rng: mathutil.NewWithTime(), - StatsLoadSyncWait: StatsLoadSyncWait.Load(), EnableLegacyInstanceScope: DefEnableLegacyInstanceScope, RemoveOrderbyInSubquery: DefTiDBRemoveOrderbyInSubquery, EnableSkewDistinctAgg: DefTiDBSkewDistinctAgg, @@ -2003,6 +2002,7 @@ func NewSessionVars(hctx HookContext) *SessionVars { vars.DiskTracker = disk.NewTracker(memory.LabelForSession, -1) vars.MemTracker = memory.NewTracker(memory.LabelForSession, vars.MemQuotaQuery) vars.MemTracker.IsRootTrackerOfSess = true + vars.StatsLoadSyncWait.Store(StatsLoadSyncWait.Load()) for _, engine := range config.GetGlobalConfig().IsolationRead.Engines { switch engine { diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index e3107b6e50f32..26702e212ed4c 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -2073,7 +2073,7 @@ var defaultSysVars = []*SysVar{ }}, {Scope: ScopeGlobal | ScopeSession, Name: TiDBStatsLoadSyncWait, Value: strconv.Itoa(DefTiDBStatsLoadSyncWait), Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, SetSession: func(s *SessionVars, val string) error { - s.StatsLoadSyncWait = TidbOptInt64(val, DefTiDBStatsLoadSyncWait) + s.StatsLoadSyncWait.Store(TidbOptInt64(val, DefTiDBStatsLoadSyncWait)) return nil }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { diff --git a/statistics/handle/BUILD.bazel b/statistics/handle/BUILD.bazel index e3a29db97752d..fe11939f9a054 100644 --- a/statistics/handle/BUILD.bazel +++ b/statistics/handle/BUILD.bazel @@ -79,6 +79,7 @@ go_test( "//domain", "//parser/model", "//parser/mysql", + "//sessionctx", "//sessionctx/stmtctx", "//sessionctx/variable", "//statistics", diff --git a/statistics/handle/handle_hist.go b/statistics/handle/handle_hist.go index cee091f2876cb..0fe016d06601a 100644 --- a/statistics/handle/handle_hist.go +++ b/statistics/handle/handle_hist.go @@ -198,7 +198,7 @@ func (h *Handle) SubLoadWorker(ctx sessionctx.Context, exit chan struct{}, exitW // if the last task is not successfully handled in last round for error or panic, pass it to this round to retry var lastTask *NeededItemTask for { - task, err := h.HandleOneTask(lastTask, readerCtx, ctx.(sqlexec.RestrictedSQLExecutor), exit) + task, err := h.HandleOneTask(ctx, lastTask, readerCtx, ctx.(sqlexec.RestrictedSQLExecutor), exit) lastTask = task if err != nil { switch err { @@ -216,7 +216,7 @@ func (h *Handle) SubLoadWorker(ctx sessionctx.Context, exit chan struct{}, exitW } // HandleOneTask handles last task if not nil, else handle a new task from chan, and return current task if fail somewhere. 
-func (h *Handle) HandleOneTask(lastTask *NeededItemTask, readerCtx *StatsReaderContext, ctx sqlexec.RestrictedSQLExecutor, exit chan struct{}) (task *NeededItemTask, err error) { +func (h *Handle) HandleOneTask(sctx sessionctx.Context, lastTask *NeededItemTask, readerCtx *StatsReaderContext, ctx sqlexec.RestrictedSQLExecutor, exit chan struct{}) (task *NeededItemTask, err error) { defer func() { // recover for each task, worker keeps working if r := recover(); r != nil { @@ -225,7 +225,7 @@ func (h *Handle) HandleOneTask(lastTask *NeededItemTask, readerCtx *StatsReaderC } }() if lastTask == nil { - task, err = h.drainColTask(exit) + task, err = h.drainColTask(sctx, exit) if err != nil { if err != errExit { logutil.BgLogger().Error("Fail to drain task for stats loading.", zap.Error(err)) @@ -408,7 +408,7 @@ func (h *Handle) readStatsForOneItem(item model.TableItemID, w *statsWrapper, re } // drainColTask will hang until a column task can return, and either task or error will be returned. -func (h *Handle) drainColTask(exit chan struct{}) (*NeededItemTask, error) { +func (h *Handle) drainColTask(sctx sessionctx.Context, exit chan struct{}) (*NeededItemTask, error) { // select NeededColumnsCh firstly, if no task, then select TimeoutColumnsCh for { select { @@ -421,6 +421,7 @@ func (h *Handle) drainColTask(exit chan struct{}) (*NeededItemTask, error) { // if the task has already timeout, no sql is sync-waiting for it, // so do not handle it just now, put it to another channel with lower priority if time.Now().After(task.ToTimeout) { + task.ToTimeout = task.ToTimeout.Add(time.Duration(sctx.GetSessionVars().StatsLoadSyncWait.Load()) * time.Microsecond) h.writeToTimeoutChan(h.StatsLoad.TimeoutItemsCh, task) continue } diff --git a/statistics/handle/handle_hist_test.go b/statistics/handle/handle_hist_test.go index 8febf5827165d..bd38bb628bfd8 100644 --- a/statistics/handle/handle_hist_test.go +++ b/statistics/handle/handle_hist_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics/handle" "github.com/pingcap/tidb/testkit" @@ -205,7 +206,7 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) { exitCh := make(chan struct{}) require.NoError(t, failpoint.Enable(fp.failPath, fp.inTerms)) - task1, err1 := h.HandleOneTask(nil, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) + task1, err1 := h.HandleOneTask(testKit.Session().(sessionctx.Context), nil, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) require.Error(t, err1) require.NotNil(t, task1) list, ok := h.StatsLoad.WorkingColMap[neededColumns[0]] @@ -213,7 +214,7 @@ require.Len(t, list, 1) require.Equal(t, stmtCtx1.StatsLoad.ResultCh, list[0]) - task2, err2 := h.HandleOneTask(nil, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) + task2, err2 := h.HandleOneTask(testKit.Session().(sessionctx.Context), nil, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) require.Nil(t, err2) require.Nil(t, task2) list, ok = h.StatsLoad.WorkingColMap[neededColumns[0]] @@ -222,7 +223,7 @@ require.Equal(t, stmtCtx2.StatsLoad.ResultCh, list[1]) require.NoError(t, failpoint.Disable(fp.failPath)) - task3, err3 := h.HandleOneTask(task1, readerCtx, 
testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) + task3, err3 := h.HandleOneTask(testKit.Session().(sessionctx.Context), task1, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) require.NoError(t, err3) require.Nil(t, task3) diff --git a/statistics/histogram.go b/statistics/histogram.go index 642836c83a941..aba5776ce1bae 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -1094,11 +1094,11 @@ func newHistogramBySelectivity(sctx sessionctx.Context, histID int64, oldHist, n // NewHistCollBySelectivity creates new HistColl by the given statsNodes. func (coll *HistColl) NewHistCollBySelectivity(sctx sessionctx.Context, statsNodes []*StatsNode) *HistColl { newColl := &HistColl{ - Columns: make(map[int64]*Column), - Indices: make(map[int64]*Index), - Idx2ColumnIDs: coll.Idx2ColumnIDs, - ColID2IdxIDs: coll.ColID2IdxIDs, - RealtimeCount: coll.RealtimeCount, + Columns: make(map[int64]*Column), + Indices: make(map[int64]*Index), + Idx2ColUniqueIDs: coll.Idx2ColUniqueIDs, + ColUniqueID2IdxIDs: coll.ColUniqueID2IdxIDs, + RealtimeCount: coll.RealtimeCount, } for _, node := range statsNodes { if node.Tp == IndexType { diff --git a/statistics/index.go b/statistics/index.go index 1217a641d89eb..88c3bbfb55b06 100644 --- a/statistics/index.go +++ b/statistics/index.go @@ -395,7 +395,7 @@ func (idx *Index) expBackoffEstimation(sctx sessionctx.Context, coll *HistColl, Collators: make([]collate.Collator, 1), }, } - colsIDs := coll.Idx2ColumnIDs[idx.Histogram.ID] + colsIDs := coll.Idx2ColUniqueIDs[idx.Histogram.ID] singleColumnEstResults := make([]float64, 0, len(indexRange.LowVal)) // The following codes uses Exponential Backoff to reduce the impact of independent assumption. It works like: // 1. Calc the selectivity of each column. @@ -420,7 +420,7 @@ func (idx *Index) expBackoffEstimation(sctx sessionctx.Context, coll *HistColl, foundStats = true count, err = coll.GetRowCountByColumnRanges(sctx, colID, tmpRan) } - if idxIDs, ok := coll.ColID2IdxIDs[colID]; ok && !foundStats && len(indexRange.LowVal) > 1 { + if idxIDs, ok := coll.ColUniqueID2IdxIDs[colID]; ok && !foundStats && len(indexRange.LowVal) > 1 { // Note the `len(indexRange.LowVal) > 1` condition here, it means we only recursively call // `GetRowCountByIndexRanges()` when the input `indexRange` is a multi-column range. This // check avoids infinite recursion. 
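The exponential backoff mentioned in the comment above is easy to state concretely: per-column selectivities are combined not by straight multiplication (the independence assumption) but with a damping exponent on each successive factor. The schedule commonly described for expBackoffEstimation is sel(1) * sel(2)^(1/2) * sel(3)^(1/4) * sel(4)^(1/8) over the four most selective columns; the standalone sketch below assumes that schedule and uses made-up selectivities.

package main

import (
	"fmt"
	"math"
	"sort"
)

// expBackoff combines per-column selectivities with the damped schedule
// sel[0] * sel[1]^(1/2) * sel[2]^(1/4) * sel[3]^(1/8), using at most the
// four most selective (smallest) values.
func expBackoff(sels []float64) float64 {
	sort.Float64s(sels) // most selective first
	if len(sels) > 4 {
		sels = sels[:4]
	}
	result := 1.0
	for i, sel := range sels {
		result *= math.Pow(sel, 1.0/math.Pow(2, float64(i)))
	}
	return result
}

func main() {
	// Illustrative values only: three columns with selectivities 1%, 10%, 50%.
	fmt.Printf("%.6f\n", expBackoff([]float64{0.5, 0.01, 0.1}))
	// ≈ 0.002659, between the independence estimate (0.01*0.1*0.5 = 0.0005)
	// and the most selective single column (0.01).
}

Damping each successive factor hedges against correlated columns, for which the plain independence product is too optimistic.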
diff --git a/statistics/selectivity.go b/statistics/selectivity.go index 8205286880513..1c42becdd0ead 100644 --- a/statistics/selectivity.go +++ b/statistics/selectivity.go @@ -290,7 +290,7 @@ func (coll *HistColl) Selectivity( slices.Sort(idxIDs) for _, id := range idxIDs { idxStats := coll.Indices[id] - idxCols := FindPrefixOfIndexByCol(extractedCols, coll.Idx2ColumnIDs[id], id2Paths[idxStats.ID]) + idxCols := FindPrefixOfIndexByCol(extractedCols, coll.Idx2ColUniqueIDs[id], id2Paths[idxStats.ID]) if len(idxCols) > 0 { lengths := make([]int, 0, len(idxCols)) for i := 0; i < len(idxCols) && i < len(idxStats.Info.Columns); i++ { diff --git a/statistics/selectivity_test.go b/statistics/selectivity_test.go index 03ac9ae59ff21..17d8774ac1aa5 100644 --- a/statistics/selectivity_test.go +++ b/statistics/selectivity_test.go @@ -1106,8 +1106,8 @@ func generateMapsForMockStatsTbl(statsTbl *statistics.Table) { for _, idxIDs := range colID2IdxIDs { slices.Sort(idxIDs) } - statsTbl.Idx2ColumnIDs = idx2Columns - statsTbl.ColID2IdxIDs = colID2IdxIDs + statsTbl.Idx2ColUniqueIDs = idx2Columns + statsTbl.ColUniqueID2IdxIDs = colID2IdxIDs } func TestIssue39593(t *testing.T) { diff --git a/statistics/table.go b/statistics/table.go index f63b600f339f0..8d72733cef3a9 100644 --- a/statistics/table.go +++ b/statistics/table.go @@ -106,13 +106,11 @@ const ( // HistColl is a collection of histogram. It collects enough information for plan to calculate the selectivity. type HistColl struct { - PhysicalID int64 + // Note that when used in a query, Columns uses UniqueID as the key, while Indices uses the index ID from the + // table metadata. (See GenerateHistCollFromColumnInfo() for details.) Columns map[int64]*Column Indices map[int64]*Index - // Idx2ColumnIDs maps the index id to its column ids. It's used to calculate the selectivity in planner. - Idx2ColumnIDs map[int64][]int64 - // ColID2IdxIDs maps the column id to a list index ids whose first column is it. It's used to calculate the selectivity in planner. - ColID2IdxIDs map[int64][]int64 + PhysicalID int64 // TODO: add AnalyzeCount here RealtimeCount int64 // RealtimeCount is the current table row count, maintained by applying stats delta based on AnalyzeCount. ModifyCount int64 // Total modify count in a table. @@ -121,6 +119,18 @@ type HistColl struct { // The physical id is used when try to load column stats from storage. HavePhysicalID bool Pseudo bool + + /* + Fields below are only used in a query (e.g., for cardinality estimation); they are useless when stored in + the stats cache. (See GenerateHistCollFromColumnInfo() for details.) + */ + + // Idx2ColUniqueIDs maps an index ID to the UniqueIDs of its columns. It's used to calculate selectivity in the planner. + Idx2ColUniqueIDs map[int64][]int64 + // ColUniqueID2IdxIDs maps a column UniqueID to the list of index IDs whose first column is that column. It's used to calculate selectivity in the planner. + ColUniqueID2IdxIDs map[int64][]int64 + // UniqueID2colInfoID maps a column UniqueID to its column ID in the table metadata. + UniqueID2colInfoID map[int64]int64 } // TableMemoryUsage records tbl memory usage @@ -561,19 +571,23 @@ func (t *Table) ColumnEqualRowCount(sctx sessionctx.Context, value types.Datum, } // GetRowCountByIntColumnRanges estimates the row count by a slice of IntColumnRange. 
-func (coll *HistColl) GetRowCountByIntColumnRanges(sctx sessionctx.Context, colID int64, intRanges []*ranger.Range) (result float64, err error) { +func (coll *HistColl) GetRowCountByIntColumnRanges(sctx sessionctx.Context, colUniqueID int64, intRanges []*ranger.Range) (result float64, err error) { var name string if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) - debugTraceGetRowCountInput(sctx, colID, intRanges) + debugTraceGetRowCountInput(sctx, colUniqueID, intRanges) defer func() { debugtrace.RecordAnyValuesWithNames(sctx, "Name", name, "Result", result) debugtrace.LeaveContextCommon(sctx) }() } sc := sctx.GetSessionVars().StmtCtx - c, ok := coll.Columns[colID] - recordUsedItemStatsStatus(sctx, c, coll.PhysicalID, colID) + c, ok := coll.Columns[colUniqueID] + colInfoID := colUniqueID + if len(coll.UniqueID2colInfoID) > 0 { + colInfoID = coll.UniqueID2colInfoID[colUniqueID] + } + recordUsedItemStatsStatus(sctx, c, coll.PhysicalID, colInfoID) if c != nil && c.Info != nil { name = c.Info.Name.O } @@ -606,19 +620,23 @@ func (coll *HistColl) GetRowCountByIntColumnRanges(sctx sessionctx.Context, colI } // GetRowCountByColumnRanges estimates the row count by a slice of Range. -func (coll *HistColl) GetRowCountByColumnRanges(sctx sessionctx.Context, colID int64, colRanges []*ranger.Range) (result float64, err error) { +func (coll *HistColl) GetRowCountByColumnRanges(sctx sessionctx.Context, colUniqueID int64, colRanges []*ranger.Range) (result float64, err error) { var name string if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) - debugTraceGetRowCountInput(sctx, colID, colRanges) + debugTraceGetRowCountInput(sctx, colUniqueID, colRanges) defer func() { debugtrace.RecordAnyValuesWithNames(sctx, "Name", name, "Result", result) debugtrace.LeaveContextCommon(sctx) }() } sc := sctx.GetSessionVars().StmtCtx - c, ok := coll.Columns[colID] - recordUsedItemStatsStatus(sctx, c, coll.PhysicalID, colID) + c, ok := coll.Columns[colUniqueID] + colInfoID := colUniqueID + if len(coll.UniqueID2colInfoID) > 0 { + colInfoID = coll.UniqueID2colInfoID[colUniqueID] + } + recordUsedItemStatsStatus(sctx, c, coll.PhysicalID, colInfoID) if c != nil && c.Info != nil { name = c.Info.Name.O } @@ -733,7 +751,7 @@ func (coll *HistColl) findAvailableStatsForCol(sctx sessionctx.Context, uniqueID return false, uniqueID } // try to find available stats in single column index stats (except for prefix index) - for idxStatsIdx, cols := range coll.Idx2ColumnIDs { + for idxStatsIdx, cols := range coll.Idx2ColUniqueIDs { if len(cols) == 1 && cols[0] == uniqueID { idxStats, ok := coll.Indices[idxStatsIdx] if ok && @@ -938,13 +956,15 @@ func (coll *HistColl) ID2UniqueID(columns []*expression.Column) *HistColl { return newColl } -// GenerateHistCollFromColumnInfo generates a new HistColl whose ColID2IdxIDs and IdxID2ColIDs is built from the given parameter. +// GenerateHistCollFromColumnInfo generates a new HistColl whose ColUniqueID2IdxIDs and Idx2ColUniqueIDs are built from the given parameters. 
func (coll *HistColl) GenerateHistCollFromColumnInfo(tblInfo *model.TableInfo, columns []*expression.Column) *HistColl { newColHistMap := make(map[int64]*Column) colInfoID2UniqueID := make(map[int64]int64, len(columns)) idxID2idxInfo := make(map[int64]*model.IndexInfo) + uniqueID2colInfoID := make(map[int64]int64, len(columns)) for _, col := range columns { colInfoID2UniqueID[col.ID] = col.UniqueID + uniqueID2colInfoID[col.UniqueID] = col.ID } for id, colHist := range coll.Columns { uniqueID, ok := colInfoID2UniqueID[id] @@ -984,15 +1004,16 @@ func (coll *HistColl) GenerateHistCollFromColumnInfo(tblInfo *model.TableInfo, c slices.Sort(idxIDs) } newColl := &HistColl{ - PhysicalID: coll.PhysicalID, - HavePhysicalID: coll.HavePhysicalID, - Pseudo: coll.Pseudo, - RealtimeCount: coll.RealtimeCount, - ModifyCount: coll.ModifyCount, - Columns: newColHistMap, - Indices: newIdxHistMap, - ColID2IdxIDs: colID2IdxIDs, - Idx2ColumnIDs: idx2Columns, + PhysicalID: coll.PhysicalID, + HavePhysicalID: coll.HavePhysicalID, + Pseudo: coll.Pseudo, + RealtimeCount: coll.RealtimeCount, + ModifyCount: coll.ModifyCount, + Columns: newColHistMap, + Indices: newIdxHistMap, + ColUniqueID2IdxIDs: colID2IdxIDs, + Idx2ColUniqueIDs: idx2Columns, + UniqueID2colInfoID: uniqueID2colInfoID, } return newColl } @@ -1065,7 +1086,7 @@ func (coll *HistColl) crossValidationSelectivity( }() } minRowCount = math.MaxFloat64 - cols := coll.Idx2ColumnIDs[idx.ID] + cols := coll.Idx2ColUniqueIDs[idx.ID] crossValidationSelectivity = 1.0 totalRowCount := idx.TotalRowCount() for i, colID := range cols { @@ -1134,7 +1155,7 @@ func (coll *HistColl) getEqualCondSelectivity(sctx sessionctx.Context, idx *Inde return outOfRangeEQSelectivity(sctx, idx.NDV, coll.RealtimeCount, int64(idx.TotalRowCount())), nil } // The equal condition only uses prefix columns of the index. 
- colIDs := coll.Idx2ColumnIDs[idx.ID] + colIDs := coll.Idx2ColUniqueIDs[idx.ID] var ndv int64 for i, colID := range colIDs { if i >= usedColsLen { @@ -1236,19 +1257,19 @@ func (coll *HistColl) getIndexRowCount(sctx sessionctx.Context, idxID int64, ind } var count float64 var err error - colIDs := coll.Idx2ColumnIDs[idxID] - var colID int64 - if rangePosition >= len(colIDs) { - colID = -1 + colUniqueIDs := coll.Idx2ColUniqueIDs[idxID] + var colUniqueID int64 + if rangePosition >= len(colUniqueIDs) { + colUniqueID = -1 } else { - colID = colIDs[rangePosition] + colUniqueID = colUniqueIDs[rangePosition] } // prefer index stats over column stats - if idxIDs, ok := coll.ColID2IdxIDs[colID]; ok && len(idxIDs) > 0 { + if idxIDs, ok := coll.ColUniqueID2IdxIDs[colUniqueID]; ok && len(idxIDs) > 0 { idxID := idxIDs[0] count, err = coll.GetRowCountByIndexRanges(sctx, idxID, []*ranger.Range{&rang}) } else { - count, err = coll.GetRowCountByColumnRanges(sctx, colID, []*ranger.Range{&rang}) + count, err = coll.GetRowCountByColumnRanges(sctx, colUniqueID, []*ranger.Range{&rang}) } if err != nil { return 0, errors.Trace(err) diff --git a/tests/integrationtest/r/ddl/db_rename.result b/tests/integrationtest/r/ddl/db_rename.result new file mode 100644 index 0000000000000..8b47d3aff357e --- /dev/null +++ b/tests/integrationtest/r/ddl/db_rename.result @@ -0,0 +1,27 @@ +drop table if exists t; +create table t (pk int primary key, c int default 1, c1 int default 1, unique key k1(c), key k2(c1)); +alter table t rename index k1 to k3; +admin check index t k3; +alter table t rename index k3 to k3; +admin check index t k3; +alter table t rename index x to x; +Error 1176 (42000): Key 'x' doesn't exist in table 't' +alter table t rename index k3 to k2; +Error 1061 (42000): Duplicate key name 'k2' +alter table t rename index k2 to K2; +alter table t rename key k3 to K2; +Error 1061 (42000): Duplicate key name 'K2' +drop table t; +create table t(j json); +alter table t add index idx1((cast(j as char(10) array))); +alter table t rename index idx1 to idx2; +alter table t add index idx1((cast(j as char(10) array))); +insert into t values ('["1"]'); +alter table t add index IDX3((cast(j as char(10) array))); +alter table t rename index IDX3 to IDX4; +alter table t add index IDX3((cast(j as char(10) array))); +insert into t values ('["2"]'); +select * from t; +j +["1"] +["2"] diff --git a/tests/integrationtest/t/ddl/db_rename.test b/tests/integrationtest/t/ddl/db_rename.test new file mode 100644 index 0000000000000..05f3814f24e61 --- /dev/null +++ b/tests/integrationtest/t/ddl/db_rename.test @@ -0,0 +1,27 @@ +# TestRenameIndex +drop table if exists t; +create table t (pk int primary key, c int default 1, c1 int default 1, unique key k1(c), key k2(c1)); +alter table t rename index k1 to k3; +admin check index t k3; +alter table t rename index k3 to k3; +admin check index t k3; +-- error 1176 +alter table t rename index x to x; +-- error 1061 +alter table t rename index k3 to k2; +alter table t rename index k2 to K2; +-- error 1061 +alter table t rename key k3 to K2; + +# TestIssue51431 +drop table t; +create table t(j json); +alter table t add index idx1((cast(j as char(10) array))); +alter table t rename index idx1 to idx2; +alter table t add index idx1((cast(j as char(10) array))); +insert into t values ('["1"]'); +alter table t add index IDX3((cast(j as char(10) array))); +alter table t rename index IDX3 to IDX4; +alter table t add index IDX3((cast(j as char(10) array))); +insert into t values ('["2"]'); +select * from t; 
diff --git a/tests/realtikvtest/addindextest/add_index_test.go b/tests/realtikvtest/addindextest/add_index_test.go index 7f9fa71b9a977..2fdf6ddda1177 100644 --- a/tests/realtikvtest/addindextest/add_index_test.go +++ b/tests/realtikvtest/addindextest/add_index_test.go @@ -192,3 +192,15 @@ func TestAddIndexDistCancel(t *testing.T) { tk.MustExec(`set global tidb_enable_dist_task=0;`) } + +func TestAddUKWithSmallIntHandles(t *testing.T) { + store := realtikvtest.CreateMockStoreAndSetup(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("drop database if exists small;") + tk.MustExec("create database small;") + tk.MustExec("use small;") + tk.MustExec(`set global tidb_ddl_enable_fast_reorg=1;`) + tk.MustExec("create table t (a bigint, b int, primary key (a) clustered)") + tk.MustExec("insert into t values (-9223372036854775808, 1),(-9223372036854775807, 1)") + tk.MustContainErrMsg("alter table t add unique index uk(b)", "Duplicate entry '1' for key 't.uk'") +} diff --git a/tests/realtikvtest/addindextest/integration_test.go b/tests/realtikvtest/addindextest/integration_test.go index e05b7f670db4e..bb6af7e925673 100644 --- a/tests/realtikvtest/addindextest/integration_test.go +++ b/tests/realtikvtest/addindextest/integration_test.go @@ -555,3 +555,13 @@ func TestAddIndexRemoteDuplicateCheck(t *testing.T) { tk.MustGetErrCode("alter table t add unique index idx(b);", errno.ErrDupEntry) ingest.ForceSyncFlagForTest = false } + +func TestAddUniqueIndexDuplicatedError(t *testing.T) { + store := realtikvtest.CreateMockStoreAndSetup(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("DROP TABLE IF EXISTS `b1cce552` ") + tk.MustExec("CREATE TABLE `b1cce552` (\n `f5d9aecb` timestamp DEFAULT '2031-12-22 06:44:52',\n `d9337060` varchar(186) DEFAULT 'duplicatevalue',\n `4c74082f` year(4) DEFAULT '1977',\n `9215adc3` tinytext DEFAULT NULL,\n `85ad5a07` decimal(5,0) NOT NULL DEFAULT '68649',\n `8c60260f` varchar(130) NOT NULL DEFAULT 'drfwe301tuehhkmk0jl79mzekuq0byg',\n `8069da7b` varchar(90) DEFAULT 'ra5rhqzgjal4o47ppr33xqjmumpiiillh7o5ajx7gohmuroan0u',\n `91e218e1` tinytext DEFAULT NULL,\n PRIMARY KEY (`8c60260f`,`85ad5a07`) /*T![clustered_index] CLUSTERED */,\n KEY `d88975e1` (`8069da7b`)\n);") + tk.MustExec("INSERT INTO `b1cce552` (`f5d9aecb`, `d9337060`, `4c74082f`, `9215adc3`, `85ad5a07`, `8c60260f`, `8069da7b`, `91e218e1`) VALUES ('2031-12-22 06:44:52', 'duplicatevalue', 2028, NULL, 846, 'N6QD1=@ped@owVoJx', '9soPM2d6H', 'Tv%'), ('2031-12-22 06:44:52', 'duplicatevalue', 2028, NULL, 9052, '_HWaf#gD!bw', '9soPM2d6H', 'Tv%');") + tk.MustGetErrCode("ALTER TABLE `b1cce552` ADD unique INDEX `65290727` (`4c74082f`, `d9337060`, `8069da7b`);", errno.ErrDupEntry) +} diff --git a/util/cpuprofile/testutil/util.go b/util/cpuprofile/testutil/util.go index c0e72d0a1dbb7..46c23bb5e0a23 100644 --- a/util/cpuprofile/testutil/util.go +++ b/util/cpuprofile/testutil/util.go @@ -16,6 +16,7 @@ package testutil import ( "context" + "encoding/hex" "runtime/pprof" ) @@ -24,7 +25,8 @@ func MockCPULoad(ctx context.Context, labels ...string) { lvs := []string{} for _, label := range labels { lvs = append(lvs, label) - lvs = append(lvs, label+" value") + val := hex.EncodeToString([]byte(label + " value")) + lvs = append(lvs, val) // start goroutine with only 1 label. 
go mockCPULoadByGoroutineWithLabel(ctx, label, label+" value") } diff --git a/util/topsql/collector/BUILD.bazel b/util/topsql/collector/BUILD.bazel index 0bb7e073a9d4b..bc48ae5ec0ef8 100644 --- a/util/topsql/collector/BUILD.bazel +++ b/util/topsql/collector/BUILD.bazel @@ -8,7 +8,6 @@ go_library( deps = [ "//util", "//util/cpuprofile", - "//util/hack", "//util/logutil", "//util/topsql/state", "@com_github_google_pprof//profile", diff --git a/util/topsql/collector/cpu.go b/util/topsql/collector/cpu.go index fbd7fc2d9c025..8bda153161c0a 100644 --- a/util/topsql/collector/cpu.go +++ b/util/topsql/collector/cpu.go @@ -16,6 +16,7 @@ package collector import ( "context" + "encoding/hex" "runtime/pprof" "sync" "time" @@ -23,7 +24,6 @@ import ( "github.com/google/pprof/profile" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/cpuprofile" - "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" topsqlstate "github.com/pingcap/tidb/util/topsql/state" "go.uber.org/zap" @@ -193,12 +193,25 @@ func (sp *SQLCPUCollector) parseCPUProfileBySQLLabels(p *profile.Profile) []SQLC func (*SQLCPUCollector) createSQLStats(sqlMap map[string]*sqlStats) []SQLCPUTimeRecord { stats := make([]SQLCPUTimeRecord, 0, len(sqlMap)) - for sqlDigest, stmt := range sqlMap { + for hexSQLDigest, stmt := range sqlMap { stmt.tune() - for planDigest, val := range stmt.plans { + + sqlDigest, err := hex.DecodeString(hexSQLDigest) + if err != nil { + logutil.BgLogger().Error("decode sql digest failed", zap.String("sqlDigest", hexSQLDigest), zap.Error(err)) + continue + } + + for hexPlanDigest, val := range stmt.plans { + planDigest, err := hex.DecodeString(hexPlanDigest) + if err != nil { + logutil.BgLogger().Error("decode plan digest failed", zap.String("planDigest", hexPlanDigest), zap.Error(err)) + continue + } + stats = append(stats, SQLCPUTimeRecord{ - SQLDigest: []byte(sqlDigest), - PlanDigest: []byte(planDigest), + SQLDigest: sqlDigest, + PlanDigest: planDigest, CPUTimeMs: uint32(time.Duration(val).Milliseconds()), }) } @@ -255,12 +268,12 @@ func (s *sqlStats) tune() { } // CtxWithSQLDigest wrap the ctx with sql digest. -func CtxWithSQLDigest(ctx context.Context, sqlDigest []byte) context.Context { - return pprof.WithLabels(ctx, pprof.Labels(labelSQLDigest, string(hack.String(sqlDigest)))) +func CtxWithSQLDigest(ctx context.Context, sqlDigest string) context.Context { + return pprof.WithLabels(ctx, pprof.Labels(labelSQLDigest, sqlDigest)) } // CtxWithSQLAndPlanDigest wrap the ctx with sql digest and plan digest. 
-func CtxWithSQLAndPlanDigest(ctx context.Context, sqlDigest, planDigest []byte) context.Context { - return pprof.WithLabels(ctx, pprof.Labels(labelSQLDigest, string(hack.String(sqlDigest)), - labelPlanDigest, string(hack.String(planDigest)))) +func CtxWithSQLAndPlanDigest(ctx context.Context, sqlDigest, planDigest string) context.Context { + return pprof.WithLabels(ctx, pprof.Labels(labelSQLDigest, sqlDigest, + labelPlanDigest, planDigest)) } diff --git a/util/topsql/reporter/BUILD.bazel b/util/topsql/reporter/BUILD.bazel index d910cc5f425cf..d8334905d3fd5 100644 --- a/util/topsql/reporter/BUILD.bazel +++ b/util/topsql/reporter/BUILD.bazel @@ -52,6 +52,7 @@ go_test( "//util/topsql/reporter/mock", "//util/topsql/state", "//util/topsql/stmtstats", + "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_tipb//go-tipb", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/util/topsql/reporter/pubsub.go b/util/topsql/reporter/pubsub.go index cdf93c7c2aa3a..198ad61ee921f 100644 --- a/util/topsql/reporter/pubsub.go +++ b/util/topsql/reporter/pubsub.go @@ -19,6 +19,8 @@ import ( "errors" "time" + tidberrors "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/logutil" reporter_metrics "github.com/pingcap/tidb/util/topsql/reporter/metrics" @@ -97,6 +99,11 @@ func (ds *pubSubDataSink) OnReporterClosing() { func (ds *pubSubDataSink) run() error { defer func() { + if r := recover(); r != nil { + err := tidberrors.Errorf("%v", r) + // Catch the panic that may happen when logging a gRPC error. See https://github.com/pingcap/tidb/issues/51301. + logutil.BgLogger().Error("[top-sql] got panic in pub sub data sink, just ignore", zap.Error(err)) + } ds.registerer.Deregister(ds) ds.cancel() }() @@ -133,6 +140,7 @@ func (ds *pubSubDataSink) run() error { return ctx.Err() } + failpoint.Inject("mockGrpcLogPanic", nil) if err != nil { logutil.BgLogger().Warn( "[top-sql] pubsub datasink failed to send data to subscriber", diff --git a/util/topsql/reporter/pubsub_test.go b/util/topsql/reporter/pubsub_test.go index 20e26bd972dc2..8ba35aabeb171 100644 --- a/util/topsql/reporter/pubsub_test.go +++ b/util/topsql/reporter/pubsub_test.go @@ -20,8 +20,10 @@ import ( "testing" "time" + "github.com/pingcap/failpoint" "github.com/pingcap/tipb/go-tipb" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" ) @@ -85,6 +87,8 @@ func TestPubSubDataSink(t *testing.T) { _ = ds.run() }() + panicPath := "github.com/pingcap/tidb/util/topsql/reporter/mockGrpcLogPanic" + require.NoError(t, failpoint.Enable(panicPath, "panic")) err := ds.TrySend(&ReportData{ DataRecords: []tipb.TopSQLRecord{{ SqlDigest: []byte("S1"), @@ -117,4 +121,5 @@ func TestPubSubDataSink(t *testing.T) { mockStream.Unlock() ds.OnReporterClosing() + require.NoError(t, failpoint.Disable(panicPath)) } diff --git a/util/topsql/topsql.go b/util/topsql/topsql.go index 97e10dd58e498..5c23126d43499 100644 --- a/util/topsql/topsql.go +++ b/util/topsql/topsql.go @@ -97,11 +97,11 @@ func RegisterPlan(normalizedPlan string, planDigest *parser.Digest) { // AttachAndRegisterSQLInfo attach the sql information into Top SQL and register the SQL meta information. 
func AttachAndRegisterSQLInfo(ctx context.Context, normalizedSQL string, sqlDigest *parser.Digest, isInternal bool) context.Context { - if sqlDigest == nil || len(sqlDigest.Bytes()) == 0 { + if sqlDigest == nil || len(sqlDigest.String()) == 0 { return ctx } sqlDigestBytes := sqlDigest.Bytes() - ctx = collector.CtxWithSQLDigest(ctx, sqlDigestBytes) + ctx = collector.CtxWithSQLDigest(ctx, sqlDigest.String()) pprof.SetGoroutineLabels(ctx) linkSQLTextWithDigest(sqlDigestBytes, normalizedSQL, isInternal) @@ -124,15 +124,15 @@ func AttachAndRegisterSQLInfo(ctx context.Context, normalizedSQL string, sqlDige // AttachSQLAndPlanInfo attach the sql and plan information into Top SQL func AttachSQLAndPlanInfo(ctx context.Context, sqlDigest *parser.Digest, planDigest *parser.Digest) context.Context { - if sqlDigest == nil || len(sqlDigest.Bytes()) == 0 { + if sqlDigest == nil || len(sqlDigest.String()) == 0 { return ctx } - var planDigestBytes []byte - sqlDigestBytes := sqlDigest.Bytes() + var planDigestStr string + sqlDigestStr := sqlDigest.String() if planDigest != nil { - planDigestBytes = planDigest.Bytes() + planDigestStr = planDigest.String() } - ctx = collector.CtxWithSQLAndPlanDigest(ctx, sqlDigestBytes, planDigestBytes) + ctx = collector.CtxWithSQLAndPlanDigest(ctx, sqlDigestStr, planDigestStr) pprof.SetGoroutineLabels(ctx) failpoint.Inject("mockHighLoadForEachPlan", func(val failpoint.Value) {
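A connecting thread through the topsql changes above: pprof labels are strings, and raw digest bytes are not guaranteed to be valid UTF-8, so digests are now hex-encoded when attached as goroutine labels and hex-decoded when profile records are parsed back out. Below is a minimal sketch of that round trip; the label key and digest bytes are illustrative assumptions, not values taken from this patch.

package main

import (
	"context"
	"encoding/hex"
	"fmt"
	"runtime/pprof"
)

func main() {
	// A digest is raw bytes; hex-encode it so the pprof label is plain ASCII.
	digest := []byte{0x9e, 0x01, 0xff} // illustrative digest bytes
	hexDigest := hex.EncodeToString(digest)

	// Attach the label, as CtxWithSQLDigest above does with the digest string.
	ctx := pprof.WithLabels(context.Background(), pprof.Labels("sql_digest", hexDigest))
	pprof.SetGoroutineLabels(ctx)
	defer pprof.SetGoroutineLabels(context.Background())

	// When parsing a profile, decode the label value to recover the original
	// bytes, mirroring the hex.DecodeString calls in createSQLStats above.
	labelVal, ok := pprof.Label(ctx, "sql_digest")
	decoded, err := hex.DecodeString(labelVal)
	fmt.Println(ok, decoded, err) // true [158 1 255] <nil>
}

Hex-encoding keeps the label value stable across the profile round trip, which is what makes the byte-for-byte decode in createSQLStats safe.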