diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3e23f56c60f38..eed7fbd2b1279 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,7 +15,7 @@ A Contributor refers to the person who contributes to the following projects: ## How to become a TiDB Contributor? -If a PR (Pull Request) submitted to the TiDB / TiKV / TiSpark / PD / Docs/Docs-cn projects by you is approved and merged, then you become a TiDB Contributor. +If a PR (Pull Request) submitted to the TiDB/TiKV/TiSpark/PD/Docs/Docs-cn projects by you is approved and merged, then you become a TiDB Contributor. You are also encouraged to participate in the projects in the following ways: - Actively answer technical questions asked by community users. diff --git a/Makefile b/Makefile index 7220dda9dc99d..1658f1b02d847 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ path_to_add := $(addsuffix /bin,$(subst :,/bin:,$(GOPATH))) export PATH := $(path_to_add):$(PATH) GO := GO111MODULE=on go -GOBUILD := CGO_ENABLED=0 $(GO) build $(BUILD_FLAG) +GOBUILD := CGO_ENABLED=1 $(GO) build $(BUILD_FLAG) GOTEST := CGO_ENABLED=1 $(GO) test -p 3 OVERALLS := CGO_ENABLED=1 GO111MODULE=on overalls @@ -24,8 +24,8 @@ PACKAGES := $$($(PACKAGE_LIST)) PACKAGE_DIRECTORIES := $(PACKAGE_LIST) | sed 's|github.com/pingcap/$(PROJECT)/||' FILES := $$(find $$($(PACKAGE_DIRECTORIES)) -name "*.go") -GOFAIL_ENABLE := $$(find $$PWD/ -type d | grep -vE "(\.git|tools)" | xargs gofail enable) -GOFAIL_DISABLE := $$(find $$PWD/ -type d | grep -vE "(\.git|tools)" | xargs gofail disable) +GOFAIL_ENABLE := $$(find $$PWD/ -type d | grep -vE "(\.git|tools)" | xargs tools/bin/gofail enable) +GOFAIL_DISABLE := $$(find $$PWD/ -type d | grep -vE "(\.git|tools)" | xargs tools/bin/gofail disable) LDFLAGS += -X "github.com/pingcap/parser/mysql.TiDBReleaseVersion=$(shell git describe --tags --dirty)" LDFLAGS += -X "github.com/pingcap/tidb/util/printer.TiDBBuildTS=$(shell date -u '+%Y-%m-%d %I:%M:%S')" @@ -123,11 +123,7 @@ ifeq ("$(TRAVIS_COVERAGE)", "1") bash <(curl -s https://codecov.io/bash) endif -gotest: - @rm -rf $GOPATH/bin/gofail - $(GO) get github.com/pingcap/gofail - @which gofail - @$(GOFAIL_ENABLE) +gotest: gofail-enable ifeq ("$(TRAVIS_COVERAGE)", "1") @echo "Running in TRAVIS_COVERAGE mode." @export log_level=error; \ @@ -140,23 +136,17 @@ else endif @$(GOFAIL_DISABLE) -race: - $(GO) get github.com/pingcap/gofail - @$(GOFAIL_ENABLE) +race: gofail-enable @export log_level=debug; \ $(GOTEST) -timeout 20m -race $(PACKAGES) || { $(GOFAIL_DISABLE); exit 1; } @$(GOFAIL_DISABLE) -leak: - $(GO) get github.com/pingcap/gofail - @$(GOFAIL_ENABLE) +leak: gofail-enable @export log_level=debug; \ $(GOTEST) -tags leak $(PACKAGES) || { $(GOFAIL_DISABLE); exit 1; } @$(GOFAIL_DISABLE) -tikv_integration_test: - $(GO) get github.com/pingcap/gofail - @$(GOFAIL_ENABLE) +tikv_integration_test: gofail-enable $(GOTEST) ./store/tikv/. -with-tikv=true || { $(GOFAIL_DISABLE); exit 1; } @$(GOFAIL_DISABLE) @@ -200,37 +190,40 @@ importer: checklist: cat checklist.md -gofail-enable: +gofail-enable: tools/bin/gofail # Converting gofail failpoints... @$(GOFAIL_ENABLE) -gofail-disable: +gofail-disable: tools/bin/gofail # Restoring gofail failpoints... @$(GOFAIL_DISABLE) checkdep: $(GO) list -f '{{ join .Imports "\n" }}' github.com/pingcap/tidb/store/tikv | grep ^github.com/pingcap/parser$$ || exit 0; exit 1 -tools/bin/megacheck: +tools/bin/megacheck: tools/check/go.mod cd tools/check; \ - $go build -o ../bin/megacheck honnef.co/go/tools/cmd/megacheck + $(GO) build -o ../bin/megacheck honnef.co/go/tools/cmd/megacheck -tools/bin/revive: +tools/bin/revive: tools/check/go.mod cd tools/check; \ $(GO) build -o ../bin/revive github.com/mgechev/revive -tools/bin/goword: +tools/bin/goword: tools/check/go.mod cd tools/check; \ $(GO) build -o ../bin/goword github.com/chzchzchz/goword -tools/bin/gometalinter: +tools/bin/gometalinter: tools/check/go.mod cd tools/check; \ $(GO) build -o ../bin/gometalinter gopkg.in/alecthomas/gometalinter.v2 -tools/bin/gosec: +tools/bin/gosec: tools/check/go.mod cd tools/check; \ $(GO) build -o ../bin/gosec github.com/securego/gosec/cmd/gosec -tools/bin/errcheck: +tools/bin/errcheck: tools/check/go.mod cd tools/check; \ $(GO) build -o ../bin/errcheck github.com/kisielk/errcheck + +tools/bin/gofail: go.mod + $(GO) build -o $@ github.com/pingcap/gofail diff --git a/README.md b/README.md index 50a2291b05b5c..0029c901c0f2a 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Build Status](https://travis-ci.org/pingcap/tidb.svg?branch=master)](https://travis-ci.org/pingcap/tidb) [![Go Report Card](https://goreportcard.com/badge/github.com/pingcap/tidb)](https://goreportcard.com/report/github.com/pingcap/tidb) -![GitHub release](https://img.shields.io/github/release/pingcap/tidb.svg) +![GitHub release](https://img.shields.io/github/tag/pingcap/tidb.svg?label=release) [![CircleCI Status](https://circleci.com/gh/pingcap/tidb.svg?style=shield)](https://circleci.com/gh/pingcap/tidb) [![Coverage Status](https://codecov.io/gh/pingcap/tidb/branch/master/graph/badge.svg)](https://codecov.io/gh/pingcap/tidb) diff --git a/cmd/benchdb/main.go b/cmd/benchdb/main.go index 1585cabe8c766..152950b51c60c 100644 --- a/cmd/benchdb/main.go +++ b/cmd/benchdb/main.go @@ -118,13 +118,13 @@ func (ut *benchDB) mustExec(sql string) { if len(rss) > 0 { ctx := context.Background() rs := rss[0] - chk := rs.NewChunk() + req := rs.NewRecordBatch() for { - err := rs.Next(ctx, chk) + err := rs.Next(ctx, req) if err != nil { log.Fatal(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } } diff --git a/cmd/pluginpkg/pluginpkg.go b/cmd/pluginpkg/pluginpkg.go new file mode 100644 index 0000000000000..e1b1db5a3dba3 --- /dev/null +++ b/cmd/pluginpkg/pluginpkg.go @@ -0,0 +1,159 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + "text/template" + "time" + + "github.com/BurntSushi/toml" +) + +var ( + pkgDir string + outDir string +) + +const codeTemplate = ` +package main + +import ( + "github.com/pingcap/tidb/plugin" + "github.com/pingcap/tidb/sessionctx/variable" +) + +func PluginManifest() *plugin.Manifest { + return plugin.ExportManifest(&plugin.{{.kind}}Manifest{ + Manifest: plugin.Manifest{ + Kind: plugin.{{.kind}}, + Name: "{{.name}}", + Description: "{{.description}}", + Version: {{.version}}, + RequireVersion: map[string]uint16{}, + License: "{{.license}}", + BuildTime: "{{.buildTime}}", + SysVars: map[string]*variable.SysVar{ + {{range .sysVars}} + "{{.name}}": { + Scope: variable.Scope{{.scope}}, + Name: "{{.name}}", + Value: "{{.value}}", + }, + {{end}} + }, + Validate: {{.validate}}, + OnInit: {{.onInit}}, + OnShutdown: {{.onShutdown}}, + }, + {{range .export}} + {{.extPoint}}: {{.impl}}, + {{end}} + }) +} +` + +func init() { + flag.StringVar(&pkgDir, "pkg-dir", "", "plugin package folder path") + flag.StringVar(&outDir, "out-dir", "", "plugin packaged folder path") + flag.Usage = usage +} + +func usage() { + log.Printf("Usage: %s --pkg-dir [plugin source pkg folder] --outDir-dir [outDir-dir]\n", path.Base(os.Args[0])) + flag.PrintDefaults() + os.Exit(1) +} + +func main() { + flag.Parse() + if pkgDir == "" || outDir == "" { + flag.Usage() + } + var manifest map[string]interface{} + _, err := toml.DecodeFile(filepath.Join(pkgDir, "manifest.toml"), &manifest) + if err != nil { + log.Printf("read pkg %s's manifest failure, %+v\n", pkgDir, err) + os.Exit(1) + } + manifest["buildTime"] = time.Now().String() + + pluginName := manifest["name"].(string) + if strings.Contains(pluginName, "-") { + log.Printf("plugin name should not contain '-'\n") + os.Exit(1) + } + if pluginName != filepath.Base(pkgDir) { + log.Printf("plugin package must be same with plugin name in manifest file\n") + os.Exit(1) + } + + version := manifest["version"].(string) + tmpl, err := template.New("gen-plugin").Parse(codeTemplate) + if err != nil { + log.Printf("generate code failure during parse template, %+v\n", err) + os.Exit(1) + } + + genFileName := filepath.Join(pkgDir, filepath.Base(pkgDir)+".gen.go") + genFile, err := os.OpenFile(genFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755) + if err != nil { + log.Printf("generate code failure during prepare output file, %+v\n", err) + os.Exit(1) + } + defer func() { + err1 := os.Remove(genFileName) + if err1 != nil { + log.Printf("remove tmp file %s failure, please clean up manually at %v", genFileName, err1) + } + }() + + err = tmpl.Execute(genFile, manifest) + if err != nil { + log.Printf("generate code failure during generating code, %+v\n", err) + os.Exit(1) + } + + outputFile := filepath.Join(outDir, pluginName+"-"+version+".so") + pluginPath := `-pluginpath=` + pluginName + "-" + version + ctx := context.Background() + buildCmd := exec.CommandContext(ctx, "go", "build", + "-ldflags", pluginPath, + "-buildmode=plugin", + "-o", outputFile, pkgDir) + buildCmd.Stderr = os.Stderr + buildCmd.Stdout = os.Stdout + buildCmd.Env = append(os.Environ(), "GO111MODULE=on") + err = buildCmd.Run() + if err != nil { + log.Printf("compile plugin source code failure, %+v\n", err) + os.Exit(1) + } + fmt.Printf(`Package "%s" as plugin "%s" success.`+"\nManifest:\n", pkgDir, outputFile) + encoder := json.NewEncoder(os.Stdout) + encoder.SetIndent(" ", "\t") + err = encoder.Encode(manifest) + if err != nil { + log.Printf("print manifest detail failure, err: %v", err) + } +} diff --git a/config/config.go b/config/config.go index 0065d6448605e..3f29d75bd9301 100644 --- a/config/config.go +++ b/config/config.go @@ -74,6 +74,7 @@ type Config struct { TiKVClient TiKVClient `toml:"tikv-client" json:"tikv-client"` Binlog Binlog `toml:"binlog" json:"binlog"` CompatibleKillQuery bool `toml:"compatible-kill-query" json:"compatible-kill-query"` + Plugin Plugin `toml:"plugin" json:"plugin"` } // Log is the log section of config. @@ -240,8 +241,18 @@ type TiKVClient struct { GrpcKeepAliveTimeout uint `toml:"grpc-keepalive-timeout" json:"grpc-keepalive-timeout"` // CommitTimeout is the max time which command 'commit' will wait. CommitTimeout string `toml:"commit-timeout" json:"commit-timeout"` + // MaxTxnTimeUse is the max time a Txn may use (in seconds) from its startTS to commitTS. MaxTxnTimeUse uint `toml:"max-txn-time-use" json:"max-txn-time-use"` + + // MaxBatchSize is the max batch size when calling batch commands API. + MaxBatchSize uint `toml:"max-batch-size" json:"max-batch-size"` + // If TiKV load is greater than this, TiDB will wait for a while to avoid little batch. + OverloadThreshold uint `toml:"overload-threshold" json:"overload-threshold"` + // MaxBatchWaitTime in nanosecond is the max wait time for batch. + MaxBatchWaitTime time.Duration `toml:"max-batch-wait-time" json:"max-batch-wait-time"` + // BatchWaitSize is the max wait size for batch. + BatchWaitSize uint `toml:"batch-wait-size" json:"batch-wait-size"` } // Binlog is the config for binlog. @@ -255,6 +266,12 @@ type Binlog struct { BinlogSocket string `toml:"binlog-socket" json:"binlog-socket"` } +// Plugin is the config for plugin +type Plugin struct { + Dir string `toml:"dir" json:"dir"` + Load string `toml:"load" json:"load"` +} + var defaultConf = Config{ Host: "0.0.0.0", AdvertiseAddress: "", @@ -328,7 +345,13 @@ var defaultConf = Config{ GrpcKeepAliveTime: 10, GrpcKeepAliveTimeout: 3, CommitTimeout: "41s", - MaxTxnTimeUse: 590, + + MaxTxnTimeUse: 590, + + MaxBatchSize: 128, + OverloadThreshold: 200, + MaxBatchWaitTime: 0, + BatchWaitSize: 8, }, Binlog: Binlog{ WriteTimeout: "15s", diff --git a/config/config.toml.example b/config/config.toml.example index 376052d49bac0..622c32a13d970 100644 --- a/config/config.toml.example +++ b/config/config.toml.example @@ -231,7 +231,7 @@ grpc-keepalive-time = 10 # and if no activity is seen even after that the connection is closed. grpc-keepalive-timeout = 3 -# max time for commit command, must be twice bigger than raft election timeout. +# Max time for commit command, must be twice bigger than raft election timeout. commit-timeout = "41s" # The max time a Txn may use (in seconds) from its startTS to commitTS. @@ -239,6 +239,15 @@ commit-timeout = "41s" # value is less than gc_life_time - 10s. max-txn-time-use = 590 +# Max batch size in gRPC. +max-batch-size = 128 +# Overload threshold of TiKV. +overload-threshold = 200 +# Max batch wait time in nanosecond to avoid waiting too long. 0 means disable this feature. +max-batch-wait-time = 0 +# Batch wait size, to avoid waiting too long. +batch-wait-size = 8 + [txn-local-latches] # Enable local latches for transactions. Enable it when # there are lots of conflicts between transactions. diff --git a/config/config_test.go b/config/config_test.go index 88f60fec4b920..7ef1a687f9888 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -45,7 +45,9 @@ func (s *testConfigSuite) TestConfig(c *C) { c.Assert(err, IsNil) _, err = f.WriteString(`[performance] [tikv-client] -commit-timeout="41s"`) +commit-timeout="41s" +max-batch-size=128 +`) c.Assert(err, IsNil) c.Assert(f.Sync(), IsNil) @@ -55,6 +57,7 @@ commit-timeout="41s"`) c.Assert(conf.Binlog.Enable, Equals, true) c.Assert(conf.TiKVClient.CommitTimeout, Equals, "41s") + c.Assert(conf.TiKVClient.MaxBatchSize, Equals, uint(128)) c.Assert(f.Close(), IsNil) c.Assert(os.Remove(configFile), IsNil) diff --git a/ddl/db_test.go b/ddl/db_test.go index 377b0061aa372..4b20295015fe0 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -473,6 +473,64 @@ func (s *testDBSuite) TestCancelDropIndex(c *C) { s.mustExec(c, "alter table t drop index idx_c2") } +// TestCancelRenameIndex tests cancel ddl job which type is rename index. +func (s *testDBSuite) TestCancelRenameIndex(c *C) { + s.tk = testkit.NewTestKit(c, s.store) + s.mustExec(c, "use test_db") + s.mustExec(c, "create database if not exists test_rename_index") + s.mustExec(c, "drop table if exists t") + s.mustExec(c, "create table t(c1 int, c2 int)") + defer s.mustExec(c, "drop table t;") + for i := 0; i < 100; i++ { + s.mustExec(c, "insert into t values (?, ?)", i, i) + } + s.mustExec(c, "alter table t add index idx_c2(c2)") + var checkErr error + hook := &ddl.TestDDLCallback{} + hook.OnJobRunBeforeExported = func(job *model.Job) { + if job.Type == model.ActionRenameIndex && job.State == model.JobStateNone { + jobIDs := []int64{job.ID} + hookCtx := mock.NewContext() + hookCtx.Store = s.store + err := hookCtx.NewTxn(context.Background()) + if err != nil { + checkErr = errors.Trace(err) + return + } + txn, err := hookCtx.Txn(true) + if err != nil { + checkErr = errors.Trace(err) + return + } + errs, err := admin.CancelJobs(txn, jobIDs) + if err != nil { + checkErr = errors.Trace(err) + return + } + if errs[0] != nil { + checkErr = errors.Trace(errs[0]) + return + } + checkErr = txn.Commit(context.Background()) + } + } + originalHook := s.dom.DDL().GetHook() + s.dom.DDL().(ddl.DDLForTest).SetHook(hook) + rs, err := s.tk.Exec("alter table t rename index idx_c2 to idx_c3") + if rs != nil { + rs.Close() + } + c.Assert(checkErr, IsNil) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:12]cancelled DDL job") + s.dom.DDL().(ddl.DDLForTest).SetHook(originalHook) + t := s.testGetTable(c, "t") + for _, idx := range t.Indices() { + c.Assert(strings.EqualFold(idx.Meta().Name.L, "idx_c3"), IsFalse) + } + s.mustExec(c, "alter table t rename index idx_c2 to idx_c3") +} + // TestCancelDropTable tests cancel ddl job which type is drop table. func (s *testDBSuite) TestCancelDropTableAndSchema(c *C) { s.tk = testkit.NewTestKit(c, s.store) @@ -1619,14 +1677,49 @@ func (s *testDBSuite) testRenameTable(c *C, sql string, isAlterTable bool) { // for failure case failSQL := fmt.Sprintf(sql, "test_not_exist.t", "test_not_exist.t") - s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + } failSQL = fmt.Sprintf(sql, "test.test_not_exist", "test.test_not_exist") - s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + } failSQL = fmt.Sprintf(sql, "test.t_not_exist", "test_not_exist.t") - s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + } failSQL = fmt.Sprintf(sql, "test1.t2", "test_not_exist.t") s.testErrorCode(c, failSQL, tmysql.ErrErrorOnRename) + s.tk.MustExec("use test1") + s.tk.MustExec("create table if not exists t_exist (c1 int, c2 int)") + failSQL = fmt.Sprintf(sql, "test1.t2", "test1.t_exist") + s.testErrorCode(c, failSQL, tmysql.ErrTableExists) + failSQL = fmt.Sprintf(sql, "test.t_not_exist", "test1.t_exist") + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrTableExists) + } + failSQL = fmt.Sprintf(sql, "test_not_exist.t", "test1.t_exist") + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrTableExists) + } + failSQL = fmt.Sprintf(sql, "test_not_exist.t", "test1.t_not_exist") + if isAlterTable { + s.testErrorCode(c, failSQL, tmysql.ErrNoSuchTable) + } else { + s.testErrorCode(c, failSQL, tmysql.ErrFileNotFound) + } + // for the same table name s.tk.MustExec("use test1") s.tk.MustExec("create table if not exists t (c1 int, c2 int)") diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 7f6f2a45d1e1d..ed699dca0dc93 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -2283,10 +2283,22 @@ func (d *ddl) RenameTable(ctx sessionctx.Context, oldIdent, newIdent ast.Ident, is := d.GetInformationSchema(ctx) oldSchema, ok := is.SchemaByName(oldIdent.Schema) if !ok { + if isAlterTable { + return infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) + } + if is.TableExists(newIdent.Schema, newIdent.Name) { + return infoschema.ErrTableExists.GenWithStackByArgs(newIdent) + } return errFileNotFound.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) } oldTbl, err := is.TableByName(oldIdent.Schema, oldIdent.Name) if err != nil { + if isAlterTable { + return infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) + } + if is.TableExists(newIdent.Schema, newIdent.Name) { + return infoschema.ErrTableExists.GenWithStackByArgs(newIdent) + } return errFileNotFound.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) } if isAlterTable && newIdent.Schema.L == oldIdent.Schema.L && newIdent.Name.L == oldIdent.Name.L { diff --git a/ddl/ddl_worker_test.go b/ddl/ddl_worker_test.go index a9648bd05aa67..e375205ca4f76 100644 --- a/ddl/ddl_worker_test.go +++ b/ddl/ddl_worker_test.go @@ -375,9 +375,9 @@ func buildCancelJobTests(firstID int64) []testCancelJob { // Test create database, watch out, database id will alloc a globalID. {act: model.ActionCreateSchema, jobIDs: []int64{firstID + 12}, cancelRetErrs: noErrs, cancelState: model.StateNone, ddlRetErr: err}, - {act: model.ActionDropColumn, jobIDs: []int64{firstID + 13}, cancelRetErrs: []error{admin.ErrCancelFinishedDDLJob.GenWithStackByArgs(firstID + 13)}, cancelState: model.StateDeleteOnly, ddlRetErr: err}, - {act: model.ActionDropColumn, jobIDs: []int64{firstID + 14}, cancelRetErrs: []error{admin.ErrCancelFinishedDDLJob.GenWithStackByArgs(firstID + 14)}, cancelState: model.StateWriteOnly, ddlRetErr: err}, - {act: model.ActionDropColumn, jobIDs: []int64{firstID + 15}, cancelRetErrs: []error{admin.ErrCancelFinishedDDLJob.GenWithStackByArgs(firstID + 15)}, cancelState: model.StateWriteReorganization, ddlRetErr: err}, + {act: model.ActionDropColumn, jobIDs: []int64{firstID + 13}, cancelRetErrs: []error{admin.ErrCannotCancelDDLJob.GenWithStackByArgs(firstID + 13)}, cancelState: model.StateDeleteOnly, ddlRetErr: err}, + {act: model.ActionDropColumn, jobIDs: []int64{firstID + 14}, cancelRetErrs: []error{admin.ErrCannotCancelDDLJob.GenWithStackByArgs(firstID + 14)}, cancelState: model.StateWriteOnly, ddlRetErr: err}, + {act: model.ActionDropColumn, jobIDs: []int64{firstID + 15}, cancelRetErrs: []error{admin.ErrCannotCancelDDLJob.GenWithStackByArgs(firstID + 15)}, cancelState: model.StateWriteReorganization, ddlRetErr: err}, } return tests @@ -427,15 +427,15 @@ func (s *testDDLSuite) TestCancelJob(c *C) { dbInfo := testSchemaInfo(c, d, "test_cancel_job") testCreateSchema(c, testNewContext(d), d, dbInfo) - // create table t (c1 int, c2 int); - tblInfo := testTableInfo(c, d, "t", 2) + // create table t (c1 int, c2 int, c3 int, c4 int, c5 int); + tblInfo := testTableInfo(c, d, "t", 5) ctx := testNewContext(d) err := ctx.NewTxn(context.Background()) c.Assert(err, IsNil) job := testCreateTable(c, ctx, d, dbInfo, tblInfo) - // insert t values (1, 2); + // insert t values (1, 2, 3, 4, 5); originTable := testGetTable(c, d, dbInfo.ID, tblInfo.ID) - row := types.MakeDatums(1, 2) + row := types.MakeDatums(1, 2, 3, 4, 5) _, err = originTable.AddRecord(ctx, row) c.Assert(err, IsNil) txn, err := ctx.Txn(true) @@ -559,21 +559,26 @@ func (s *testDDLSuite) TestCancelJob(c *C) { // for drop column. test = &tests[10] - dropColName := "c2" - dropColumnArgs := []interface{}{model.NewCIStr(dropColName)} - doDDLJobErrWithSchemaState(ctx, d, c, dbInfo.ID, tblInfo.ID, model.ActionDropColumn, dropColumnArgs, &cancelState) - c.Check(errors.ErrorStack(checkErr), Equals, "") + dropColName := "c3" s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, false) + testDropColumn(c, ctx, d, dbInfo, tblInfo, dropColName, false) + c.Check(errors.ErrorStack(checkErr), Equals, "") + s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, true) test = &tests[11] - doDDLJobErrWithSchemaState(ctx, d, c, dbInfo.ID, tblInfo.ID, model.ActionDropColumn, dropColumnArgs, &cancelState) - c.Check(errors.ErrorStack(checkErr), Equals, "") + + dropColName = "c4" s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, false) + testDropColumn(c, ctx, d, dbInfo, tblInfo, dropColName, false) + c.Check(errors.ErrorStack(checkErr), Equals, "") + s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, true) test = &tests[12] - doDDLJobErrWithSchemaState(ctx, d, c, dbInfo.ID, tblInfo.ID, model.ActionDropColumn, dropColumnArgs, &cancelState) - c.Check(errors.ErrorStack(checkErr), Equals, "") + dropColName = "c5" s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, false) + testDropColumn(c, ctx, d, dbInfo, tblInfo, dropColName, false) + c.Check(errors.ErrorStack(checkErr), Equals, "") + s.checkCancelDropColumn(c, d, dbInfo.ID, tblInfo.ID, dropColName, true) } func (s *testDDLSuite) TestIgnorableSpec(c *C) { diff --git a/ddl/failtest/fail_db_test.go b/ddl/failtest/fail_db_test.go index 491704bca3eb9..a3269dea9bd20 100644 --- a/ddl/failtest/fail_db_test.go +++ b/ddl/failtest/fail_db_test.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/ddl/testutil" + ddlutil "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" @@ -110,11 +111,11 @@ func (s *testFailDBSuite) TestHalfwayCancelOperations(c *C) { // Make sure that the table's data has not been deleted. rs, err := s.se.Execute(context.Background(), "select count(*) from t") c.Assert(err, IsNil) - chk := rs[0].NewChunk() - err = rs[0].Next(context.Background(), chk) + req := rs[0].NewRecordBatch() + err = rs[0].Next(context.Background(), req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) - row := chk.GetRow(0) + c.Assert(req.NumRows() == 0, IsFalse) + row := req.GetRow(0) c.Assert(row.Len(), Equals, 1) c.Assert(row.GetInt64(0), DeepEquals, int64(1)) c.Assert(rs[0].Close(), IsNil) @@ -143,11 +144,11 @@ func (s *testFailDBSuite) TestHalfwayCancelOperations(c *C) { // Make sure that the table's data has not been deleted. rs, err = s.se.Execute(context.Background(), "select count(*) from tx") c.Assert(err, IsNil) - chk = rs[0].NewChunk() - err = rs[0].Next(context.Background(), chk) + req = rs[0].NewRecordBatch() + err = rs[0].Next(context.Background(), req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) - row = chk.GetRow(0) + c.Assert(req.NumRows() == 0, IsFalse) + row = req.GetRow(0) c.Assert(row.Len(), Equals, 1) c.Assert(row.GetInt64(0), DeepEquals, int64(1)) c.Assert(rs[0].Close(), IsNil) @@ -342,11 +343,13 @@ func (s *testFailDBSuite) TestAddIndexWorkerNum(c *C) { // Split table to multi region. s.cluster.SplitTable(s.mvccStore, tbl.Meta().ID, splitCount) + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) originDDLAddIndexWorkerCnt := variable.GetDDLReorgWorkerCounter() lastSetWorkerCnt := originDDLAddIndexWorkerCnt atomic.StoreInt32(&ddl.TestCheckWorkerNumber, lastSetWorkerCnt) ddl.TestCheckWorkerNumber = lastSetWorkerCnt - defer variable.SetDDLReorgWorkerCounter(originDDLAddIndexWorkerCnt) + defer tk.MustExec(fmt.Sprintf("set @@global.tidb_ddl_reorg_worker_cnt=%d", originDDLAddIndexWorkerCnt)) gofail.Enable("github.com/pingcap/tidb/ddl/checkIndexWorkerNum", `return(true)`) defer gofail.Disable("github.com/pingcap/tidb/ddl/checkIndexWorkerNum") @@ -364,7 +367,7 @@ LOOP: c.Assert(err, IsNil, Commentf("err:%v", errors.ErrorStack(err))) case <-ddl.TestCheckWorkerNumCh: lastSetWorkerCnt = int32(rand.Intn(8) + 8) - tk.MustExec(fmt.Sprintf("set @@tidb_ddl_reorg_worker_cnt=%d", lastSetWorkerCnt)) + tk.MustExec(fmt.Sprintf("set @@global.tidb_ddl_reorg_worker_cnt=%d", lastSetWorkerCnt)) atomic.StoreInt32(&ddl.TestCheckWorkerNumber, lastSetWorkerCnt) checkNum++ } diff --git a/ddl/index.go b/ddl/index.go index d70e0fb13736b..1fa405fb1d61d 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" + ddlutil "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" @@ -200,26 +201,11 @@ func validateRenameIndex(from, to model.CIStr, tbl *model.TableInfo) (ignore boo } func onRenameIndex(t *meta.Meta, job *model.Job) (ver int64, _ error) { - var from, to model.CIStr - if err := job.DecodeArgs(&from, &to); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - tblInfo, err := getTableInfo(t, job, job.SchemaID) + tblInfo, from, to, err := checkRenameIndex(t, job) if err != nil { - job.State = model.JobStateCancelled return ver, errors.Trace(err) } - // Double check. See function `RenameIndex` in ddl_api.go - duplicate, err := validateRenameIndex(from, to, tblInfo) - if duplicate { - return ver, nil - } - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } idx := schemautil.FindIndexByName(from.L, tblInfo.Indices) idx.Name = to if ver, err = updateVersionAndTableInfo(t, job, tblInfo, true); err != nil { @@ -436,6 +422,31 @@ func checkDropIndex(t *meta.Meta, job *model.Job) (*model.TableInfo, *model.Inde return tblInfo, indexInfo, nil } +func checkRenameIndex(t *meta.Meta, job *model.Job) (*model.TableInfo, model.CIStr, model.CIStr, error) { + var from, to model.CIStr + schemaID := job.SchemaID + tblInfo, err := getTableInfo(t, job, schemaID) + if err != nil { + return nil, from, to, errors.Trace(err) + } + + if err := job.DecodeArgs(&from, &to); err != nil { + job.State = model.JobStateCancelled + return nil, from, to, errors.Trace(err) + } + + // Double check. See function `RenameIndex` in ddl_api.go + duplicate, err := validateRenameIndex(from, to, tblInfo) + if duplicate { + return nil, from, to, nil + } + if err != nil { + job.State = model.JobStateCancelled + return nil, from, to, errors.Trace(err) + } + return tblInfo, from, to, errors.Trace(err) +} + const ( // DefaultTaskHandleCnt is default batch size of adding indices. DefaultTaskHandleCnt = 128 @@ -1070,6 +1081,17 @@ var ( TestCheckWorkerNumber = int32(16) ) +func loadDDLReorgVars(w *worker) error { + // Get sessionctx from context resource pool. + var ctx sessionctx.Context + ctx, err := w.sessPool.get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.put(ctx) + return ddlutil.LoadDDLReorgVars(ctx) +} + // addPhysicalTableIndex handles the add index reorganization state for a non-partitioned table or a partition. // For a partitioned table, it should be handled partition by partition. // @@ -1110,6 +1132,9 @@ func (w *worker) addPhysicalTableIndex(t table.PhysicalTable, indexInfo *model.I } // For dynamic adjust add index worker number. + if err := loadDDLReorgVars(w); err != nil { + log.Error(err) + } workerCnt = variable.GetDDLReorgWorkerCounter() // If only have 1 range, we can only start 1 worker. if len(kvRanges) < int(workerCnt) { diff --git a/ddl/rollingback.go b/ddl/rollingback.go index 541b7230a579b..3af1b01e37586 100644 --- a/ddl/rollingback.go +++ b/ddl/rollingback.go @@ -190,6 +190,21 @@ func rollingbackDropSchema(t *meta.Meta, job *model.Job) error { return nil } +func rollingbackRenameIndex(t *meta.Meta, job *model.Job) (ver int64, err error) { + tblInfo, from, _, err := checkRenameIndex(t, job) + if err != nil { + return ver, errors.Trace(err) + } + // Here rename index is done in a transaction, if the job is not completed, it can be canceled. + idx := schemautil.FindIndexByName(from.L, tblInfo.Indices) + if idx.State == model.StatePublic { + job.State = model.JobStateCancelled + return ver, errCancelledDDLJob + } + job.State = model.JobStateRunning + return ver, errors.Trace(err) +} + func convertJob2RollbackJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { switch job.Type { case model.ActionAddColumn: @@ -204,6 +219,8 @@ func convertJob2RollbackJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) err = rollingbackDropTableOrView(t, job) case model.ActionDropSchema: err = rollingbackDropSchema(t, job) + case model.ActionRenameIndex: + ver, err = rollingbackRenameIndex(t, job) default: job.State = model.JobStateCancelled err = errCancelledDDLJob diff --git a/ddl/util/util.go b/ddl/util/util.go index 9635c66380db2..07d6000b5db80 100644 --- a/ddl/util/util.go +++ b/ddl/util/util.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" ) @@ -68,14 +69,14 @@ func loadDeleteRangesFromTable(ctx sessionctx.Context, table string, safePoint u } rs := rss[0] - chk := rs.NewChunk() - it := chunk.NewIterator4Chunk(chk) + req := rs.NewRecordBatch() + it := chunk.NewIterator4Chunk(req.Chunk) for { - err = rs.Next(context.TODO(), chk) + err = rs.Next(context.TODO(), req) if err != nil { return nil, errors.Trace(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } @@ -128,3 +129,23 @@ func UpdateDeleteRange(ctx sessionctx.Context, dr DelRangeTask, newStartKey, old _, err := ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) return errors.Trace(err) } + +const loadDDLReorgVarsSQL = "select HIGH_PRIORITY variable_name, variable_value from mysql.global_variables where variable_name in ('" + + variable.TiDBDDLReorgWorkerCount + "', '" + + variable.TiDBDDLReorgBatchSize + "')" + +// LoadDDLReorgVars loads ddl reorg variable from mysql.global_variables. +func LoadDDLReorgVars(ctx sessionctx.Context) error { + if sctx, ok := ctx.(sqlexec.RestrictedSQLExecutor); ok { + rows, _, err := sctx.ExecRestrictedSQL(ctx, loadDDLReorgVarsSQL) + if err != nil { + return errors.Trace(err) + } + for _, row := range rows { + varName := row.GetString(0) + varValue := row.GetString(1) + variable.SetLocalSystemVar(varName, varValue) + } + } + return nil +} diff --git a/docs/design/2018-12-10-plugin-framework.md b/docs/design/2018-12-10-plugin-framework.md new file mode 100644 index 0000000000000..59b11d7ed8d2d --- /dev/null +++ b/docs/design/2018-12-10-plugin-framework.md @@ -0,0 +1,237 @@ +# Proposal: Support Plugin + +- Author(s): [lysu](https://github.com/lysu) +- Last updated: 2018-12-10 +- Discussion at: + +## Abstract + +This proposal proposes to introduce the plugin framework to TiDB to support TiDB plugin development. + +## Background + +Many cool customized requirements need to be addressed but it is not convenient to merge them into the main TiDB repository. In addition, Go 1.9+ introduces the new plugin support, so we can add a plugin framework to TiDB to make those requirements addressed, and attract more people to build the TiDB ecosystem together. + +## Proposal + +Add a plugin framework to TiDB. + +## Rationale + +Adding the plugin framework is based on Go's plugin support, but this framework supports uniform plugin manifest, package, and flexible SPI. + +The plugin developer can build a TiDB plugin in 7 steps: + +- Choose an exists plugin kind or create new plugin kind if not exists +- Create a normal go package, and add `manifest.toml` like example one. +- Implement the `validate`, `init`, `destroy` methods, which are needed for all plugins. +- Implement the kind special method to implement the plugin logic. +- Use `cmd/pluginpkg` to build plugin binary, and put the plugin binary into the plugin deployment folder. +- Start TiDB with the `-plugin-dir` and `-plugin-load` parameters. +- Run `show plugin` to check it's load status. + +## Implementation + +### Go Plugin + +We build the plugin framework based on Go's plugin support. At first, let's see "what is Go's plugin supported?" + +Go's plugin support is simple, just as the document at https://golang.org/pkg/plugin/. We can build and use the plugin in three steps: + +- Build the plugin via `go build -buildmode=plugin` in the `main` package to make plugin `.so`. +- Use `plugin.Open` to `dlopen` the plugin's `.so`. +- Use `plugin.Lookup` to `dlsym` to find the symbol in the plugin `.so`. + +There is another "undocumented" but important concept: `pluginpath`. Just as previously said we let our plugin code into the `main` package and then `go build -buildmode=plugin` to build a plugin. `pluginpath` is the package path for a plugin after plugin packaged. For example, we have a method named `DoIt` and `pluginpath` be `pkg1`, and then we can use `nm` to see the method name be `pluginpath.DoIt`. + +`pluginpath` can be given by `-ldflags -pluginpath=[path-value]` or generated by [go build](https://github.com/golang/go/blob/3b137dd2df19c261a007b8a620a2182cd679d700/src/cmd/go/internal/work/gc.go#L389)(for 1.11.1, it is the package name if `pluginpath` is built with the package folder or a content hash if built with the source file). + +If we load a Go plugin with the same `pluginpath` twice, the second `Load` call will get an error, Go plugin use `pluginpath` to detect duplicate load. + +The last thing we need to care about is the Go plugin's dependence. At first, almost all plugins need to depend on TiDB code to do its logic. Go runtime requires that the runtime hash and the link time hash for the dependence package are equal. So we do not need to care about the plugin that depends on some TiDB component. But TiDB code changes, so we need to release a new plugin whenever a TiDB new version is released. + +### TiDB Plugin + +Go plugin gives us a good start point, but we need to do something more to let the plugin be more uniform and easy to use with TiDB. + +#### Manifest + +Go Plugin gives us the ability to open a shared library. We need some meta info to self-describe the plugin, and then TiDB can know how to work with the loaded library. The information we need is as follows: + +- Plugin name: we need to reload the plugin, so we need to load the same plugin with a different version that is at a much higher level than `pluginpath`. +- Plugin version: plugin version makes it much easier to maintain. +- Simple Dependence check: Go helps to use a check to build the version, but in the real world it is common that a plugin relies on b plugin's some new logic. We try to maintain a simple Required-Version relationship between different plugins. +- Configuration: the plugin is a standalone module, so every plugin will introduce special sysVars just like normal MySQL variables. A user can use those variables to tweak plugin behaviors just like normal MySQL variables do. +- Stats: the plugin will introduce new stats info. TiDB uses Prometheus, so the plugin can easily push metrics to Prometheus. +- Plugin Category and Flexible SPI: TiDB can have limited plugin categories, a new plugin need choose a category and implement SPI defined by those category. + +All of above construct the plugin metadata which we usually call `Manifest`. `Manifest` describes the metadata and how others can use the plugin. + +We just use the Go plugin mechanism to load the TiDB plugin and the TiDB plugin gives us a `Manifest`. Then just use manifest to interact with the plugin. (Only load/lookup is heavy CGO call, and later call manifest is normal Golang method call.) + +#### SPI + +The SPI (Service Provider Interface) for the plugin is the method that returns manifest and the manifest info itself. The method that returns manifest can be generated by `pluginpkg`, so implementing SPI work for the developer is to choose and construct different manifest info (`pluginpkg` also helps with this). + +`Manifest` is the base struct for all other sub manifests. The caller can use `Kind` and `DeclareXXManifest` to convert them to sub manifests. + +Manifest provides common metadata: + +- Kind: the plugin's category. Now we have the audit kind, authentication kind, and so on. It's easy to add more. +- Name: name of the plugin, which is used to identify a plugin, so it cannot duplicate with other plugins. +- Version: we can load multiple versions a plugin into TiDB, but just activate one of them to support hot-fix or hot-upgrade. +- RequireVersion: it will make a simple relationship between different plugins. +- SysVars: it defines the plugin's configuration info. + +Manifest also provides three lifecycle extension points: + +- Validate: called after loading all plugins but before onInit, so it can do cross plugins check before init. +- OnInit: plugin can prepare resource before real work using OnInit. +- OnShutDown: plugin can clean up its resources before dying using OnShutDown. + +So we can image a common manifest code like this: + +``` +type Manifest struct { + Kind Kind + Name string + Description string + Version uint16 + RequireVersion map[string]uint16 + License string + BuildTime string + SysVars map[string]*variable.SysVar + Validate func(ctx context.Context, manifest *Manifest) error + OnInit func(ctx context.Context) error + OnShutdown func(ctx context.Context) error +} +``` + +Base on `Kind`, we can define other subManifest for authentication plugin, audit plugin and so on. + +Every subManifest will have a `Manifest` anonymous field as the FIRST field in struct definition, so every subManifest can be used as `Manifest` (by `unsafe.Pointer` cast). For example, an audit plugin' manifest will be like this: + +``` +type AuditManifest struct { + Manifest + NotifyEvent func(ctx context.Context) error +} +``` + +The reason we chose the embedded struct + unsafe.Pointer cast instead of the interface way here is that the first way is more flexible and more efficient to access data member than the fixed interface. At last, we also provide the package tools and a helper method to hide those details from the plugin developer. + +#### Package tool + +In this proposal, we add a simple tool `cmd/pluginpkg` to help package a plugin, and also uniform the package format. + +Plugin's development event no longer needs to care about previous Manifest and so on, so the developer can just provide a `manifest.toml` configuration file like this in the package: + +``` +name = "conn_ip_example" +kind = "Audit" +description = "just a test" +version = "2" +license = "" +sysVars = [ + {name="conn_ip_example_test_variable", scope="Global", value="2"}, + {name="conn_ip_example_test_variable2", scope="Session", value="2"}, +] +validate = "Validate" +onInit = "OnInit" +onShutdown = "OnShutdown" +export = [ + {extPoint="NotifyEvent", impl="NotifyEvent"} +] +``` + +- `name`: name of the plugin. It must be unique in the loaded TiDB instance. +- `kind`: kind of plugin. It determines the call-point in TiDB. The package tool is also based on it to generate a different manifest. +- `version`: the version of a plugin. For the same plugin, the same version is only loaded once. +- `description`: description of plugin usage. +- `license`: license of the plugin, which will display in `show plugins`. +- `sysVars`: it defines the variable needed by this plugin with name, scope and default value. +- `validate`: it specifies the callback function used to validate before load, e.g. auth plugin check `-with-skip-grant-tables` configuration. +- `onInit`: it specifies the callback function used to init plugin before it joins real work. +- `onShutdown`: the callback function will be called when the plugin shuts down to release outer resource held by the plugin, normally TiDB shutdown. +- `export`: it defines the callback list for the special kind plugins, e.g. for auth plugin it uses a `NotifyEvent` method to implement the `notifyEvent` extension point. + +`pluginpkg` generates code and the generated code is built as a Go plugin, and using plugin package we also control the plugin binary's format: + +- The plugin file name is `[pluginName]-[version].so`, so we can know the plugin's version from the filename. +- `pluginpath` will be `[pluginName]-[version]`, and then we can load the same plugin of a different version in the same host program. +- The package tool also adds some build time and misc info into Manifest info. + +Package tools add an abstract layer over manifest, so we can change manifest easier in future if needed. + +#### Plugin Point + +In TiDB code, we can add a new plugin point everywhere and: + +- Call `plugin.GetByKind` or `plugin.Get` to find matched plugins. +- Call `plugin.Declare[Kind]Manifest` to cast Manifest to a special kind. +- Call the extension point method for special manifest. + +We can see a simple example in `clientConn#Run` and `conn_ip_example` plugin implementation. + +Adding the new plugin point needs to modify the TiDB code and pass the required context and parameters. + +#### Configuration + +Every plugin has its own configurations. TiDB plugin uses system variables to handle configuration management requirement. + +In `manifest.toml`, we can use the `sysVar` field to provide plugin's variable name and its default value. Plugin's system variable will be registered as TiDB system variable, so the user can read/modify variable just like normal system variables. + +Plugin's variable name must use plugin name as the prefix. At last, the plugin cannot be reloaded if we change the plugin's sysVar (include default-value, add or remove variable). + +We implement it by adding the plugin variable into `variable.SysVars` before `bootstrap`, so later `doDMLWorker` will handle them just as normal sysVars, and change `loadCommonGlobalVarsSQL` to load them. (that it cannot unload plugin and cannot modify sysVar during reload makes this implementation easier) + +#### Dependency + +Go's plugin mechanism will check all dependency package hash to ensure link time and run time use the same version([see code](https://github.com/golang/go/blob/50bd1c4d4eb4fac8ddeb5f063c099daccfb71b26/src/runtime/plugin.go#L52)), so we no longer need to care about compiling package dependency. + +But for the real world, there may be logic dependency between plugins. For example, some guy writes an authorization plugin but it relies on vault plugin and only works when vault is enabled but does not directly rely on the vault plugin's source code. + +In `manifest.toml`, we can use `requireVersion` to declare A plugin requires B plugin in X version, and then plugin runtime will check it during the load or reload phase. + +### Reload + +Go plugin doesn't support unloading a plugin, but this cannot stop us from loading multiple versions of the plugin into the host program and framework, to ensure the last reloaded one will be active, and others aren't unloaded but disabled. + +So, we can reload the plugin with a different version that is packaged by `pluginpkg` to modify the plugin's implementation logic. Although we can not change the plugin's meta info (e.g. sysVars) now, it's still useful. + +#### Management + +To add a plugin to TiDB, we need to: + +- Add `-plugin-dir` as the start argument to specify the folder containing plugins, e.g. '-plugin-dir=/data/deploy/tidb/plugin'. +- Add `-plugin-load` as the start argument to specify the plugin id (name "-" version) that needs to be loaded, e.g. '-plugin-load=conn_limit-1'. + +Then starting TiDB will load and enable plugins. + +We can see all the plugins info by: + +``` +mysql> show plugins; ++-----------------+--------+-------+----------------------------------------------------+---------+---------+ +| Name | Status | Type | Library | License | Version | ++-----------------+--------+-------+----------------------------------------------------+---------+---------+ +| conn_limit-1 | Ready | Audit | /data/deploy/tidb/plugin/conn_limit-1.so | | 1 | ++-----------------+--------+-------+----------------------------------------------------+---------+---------+ +1 row in set (0.00 sec) +``` + +To reload a loaded plugin, just use + +``` +mysql> admin plugins reload conn_limit-2; +``` + +### Limitations + +The TiDB plugin has the following limitations: + +- The plugin cannot be unloaded. Once the plugin is loaded into TiDB, it can never be unloaded until the server is restarted, but we can reload the plugin in the limited situation to the hotfix plugin bug. +- Read sysVars in OnInit will get unexpected value but can access `Manifest` to get the default config value. +- Reloading cannot change the sysVar's default value or add/remove the variable. +- Building the plugin needs the TiDB source code tree, which is different from MySQL that can build plugin standalone (expect Information Schema and Storage Engine plugins) +- The plugin can only be written in Go. diff --git a/domain/domain.go b/domain/domain.go index c122e17ffb61d..367721fa0f1ef 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -33,11 +33,11 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/infoschema/perfschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/owner" - "github.com/pingcap/tidb/perfschema" "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -378,10 +378,13 @@ func (do *Domain) topNSlowQueryLoop() { do.slowQuery.Append(info) case msg := <-do.slowQuery.msgCh: req := msg.request - if req.Tp == ast.ShowSlowTop { + switch req.Tp { + case ast.ShowSlowTop: msg.result = do.slowQuery.QueryTop(int(req.Count), req.Kind) - } else if req.Tp == ast.ShowSlowRecent { + case ast.ShowSlowRecent: msg.result = do.slowQuery.QueryRecent(int(req.Count)) + default: + msg.result = do.slowQuery.QueryAll() } msg.Done() } diff --git a/domain/topn_slow_query.go b/domain/topn_slow_query.go index 0bf6721454d6c..4689fafad0a0e 100644 --- a/domain/topn_slow_query.go +++ b/domain/topn_slow_query.go @@ -160,6 +160,10 @@ func (q *topNSlowQueries) Append(info *SlowQueryInfo) { } } +func (q *topNSlowQueries) QueryAll() []*SlowQueryInfo { + return q.recent.data +} + func (q *topNSlowQueries) RemoveExpired(now time.Time) { q.user.RemoveExpired(now, q.period) q.internal.RemoveExpired(now, q.period) diff --git a/executor/adapter.go b/executor/adapter.go index f176181b7f664..f37ad05c5a1ab 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -95,18 +95,18 @@ func schema2ResultFields(schema *expression.Schema, defaultDB string) (rfs []*as // The reason we need update is that chunk with 0 rows indicating we already finished current query, we need prepare for // next query. // If stmt is not nil and chunk with some rows inside, we simply update last query found rows by the number of row in chunk. -func (a *recordSet) Next(ctx context.Context, chk *chunk.Chunk) error { +func (a *recordSet) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("recordSet.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } - err := a.executor.Next(ctx, chk) + err := a.executor.Next(ctx, req) if err != nil { a.lastErr = err return errors.Trace(err) } - numRows := chk.NumRows() + numRows := req.NumRows() if numRows == 0 { if a.stmt != nil { a.stmt.Ctx.GetSessionVars().LastFoundRows = a.stmt.Ctx.GetSessionVars().StmtCtx.FoundRows() @@ -119,9 +119,9 @@ func (a *recordSet) Next(ctx context.Context, chk *chunk.Chunk) error { return nil } -// NewChunk create a new chunk using NewChunk function in chunk package. -func (a *recordSet) NewChunk() *chunk.Chunk { - return a.executor.newFirstChunk() +// NewRecordBatch create a recordBatch base on top-level executor's newFirstChunk(). +func (a *recordSet) NewRecordBatch() *chunk.RecordBatch { + return chunk.NewRecordBatch(a.executor.newFirstChunk()) } func (a *recordSet) Close() error { @@ -295,7 +295,7 @@ func (a *ExecStmt) handleNoDelayExecutor(ctx context.Context, sctx sessionctx.Co a.LogSlowQuery(txnTS, err == nil) }() - err = e.Next(ctx, e.newFirstChunk()) + err = e.Next(ctx, chunk.NewRecordBatch(e.newFirstChunk())) if err != nil { return nil, errors.Trace(err) } @@ -396,12 +396,12 @@ func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool) { execDetail := sessVars.StmtCtx.GetExecDetails() if costTime < threshold { logutil.SlowQueryLogger.Debugf( - "[QUERY] %vcost_time:%v %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", - internal, costTime, execDetail, succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) + "[QUERY] %vcost_time:%vs %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", + internal, costTime.Seconds(), execDetail, succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) } else { logutil.SlowQueryLogger.Warnf( - "[SLOW_QUERY] %vcost_time:%v %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", - internal, costTime, execDetail, succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) + "[SLOW_QUERY] %vcost_time:%vs %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", + internal, costTime.Seconds(), execDetail, succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) metrics.TotalQueryProcHistogram.Observe(costTime.Seconds()) metrics.TotalCopProcHistogram.Observe(execDetail.ProcessTime.Seconds()) metrics.TotalCopWaitHistogram.Observe(execDetail.WaitTime.Seconds()) diff --git a/executor/admin.go b/executor/admin.go index 6759dd43945df..e31597062d807 100644 --- a/executor/admin.go +++ b/executor/admin.go @@ -60,8 +60,8 @@ type CheckIndexRangeExec struct { } // Next implements the Executor Next interface. -func (e *CheckIndexRangeExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *CheckIndexRangeExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() handleIdx := e.schema.Len() - 1 for { err := e.result.Next(ctx, e.srcChunk) @@ -76,12 +76,12 @@ func (e *CheckIndexRangeExec) Next(ctx context.Context, chk *chunk.Chunk) error handle := row.GetInt64(handleIdx) for _, hr := range e.handleRanges { if handle >= hr.Begin && handle < hr.End { - chk.AppendRow(row) + req.AppendRow(row) break } } } - if chk.NumRows() > 0 { + if req.NumRows() > 0 { return nil } } @@ -444,8 +444,8 @@ func (e *RecoverIndexExec) backfillIndexInTxn(ctx context.Context, txn kv.Transa } // Next implements the Executor Next interface. -func (e *RecoverIndexExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *RecoverIndexExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() if e.done { return nil } @@ -455,8 +455,8 @@ func (e *RecoverIndexExec) Next(ctx context.Context, chk *chunk.Chunk) error { return errors.Trace(err) } - chk.AppendInt64(0, totalAddedCnt) - chk.AppendInt64(1, totalScanCnt) + req.AppendInt64(0, totalAddedCnt) + req.AppendInt64(1, totalScanCnt) e.done = true return nil } @@ -580,8 +580,8 @@ func (e *CleanupIndexExec) fetchIndex(ctx context.Context, txn kv.Transaction) e } // Next implements the Executor Next interface. -func (e *CleanupIndexExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *CleanupIndexExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() if e.done { return nil } @@ -614,7 +614,7 @@ func (e *CleanupIndexExec) Next(ctx context.Context, chk *chunk.Chunk) error { } } e.done = true - chk.AppendUint64(0, e.removeCnt) + req.AppendUint64(0, e.removeCnt) return nil } diff --git a/executor/aggregate.go b/executor/aggregate.go index 480da96700719..1f241ea67b1a1 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -518,20 +518,20 @@ func (w *HashAggFinalWorker) run(ctx sessionctx.Context, waitGroup *sync.WaitGro } // Next implements the Executor Next interface. -func (e *HashAggExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *HashAggExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("hashagg.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() + req.Reset() if e.isUnparallelExec { - return errors.Trace(e.unparallelExec(ctx, chk)) + return errors.Trace(e.unparallelExec(ctx, req.Chunk)) } - return errors.Trace(e.parallelExec(ctx, chk)) + return errors.Trace(e.parallelExec(ctx, req.Chunk)) } func (e *HashAggExec) fetchChildData(ctx context.Context) { @@ -559,7 +559,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context) { } chk = input.chk } - err = e.children[0].Next(ctx, chk) + err = e.children[0].Next(ctx, chunk.NewRecordBatch(chk)) if err != nil { e.finalOutputCh <- &AfFinalResult{err: errors.Trace(err)} return @@ -684,7 +684,7 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro func (e *HashAggExec) execute(ctx context.Context) (err error) { inputIter := chunk.NewIterator4Chunk(e.childResult) for { - err := e.children[0].Next(ctx, e.childResult) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) if err != nil { return errors.Trace(err) } @@ -795,18 +795,18 @@ func (e *StreamAggExec) Close() error { } // Next implements the Executor Next interface. -func (e *StreamAggExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *StreamAggExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("streamAgg.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() - for !e.executed && chk.NumRows() < e.maxChunkSize { - err := e.consumeOneGroup(ctx, chk) + req.Reset() + for !e.executed && req.NumRows() < e.maxChunkSize { + err := e.consumeOneGroup(ctx, req.Chunk) if err != nil { e.executed = true return errors.Trace(err) @@ -871,7 +871,7 @@ func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Ch return errors.Trace(err) } - err = e.children[0].Next(ctx, e.childResult) + err = e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) if err != nil { return errors.Trace(err) } diff --git a/executor/analyze.go b/executor/analyze.go index 2d7703935991c..2650bb7f1e59c 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -51,7 +51,7 @@ const ( ) // Next implements the Executor Next interface. -func (e *AnalyzeExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *AnalyzeExec) Next(ctx context.Context, req *chunk.RecordBatch) error { concurrency, err := getBuildStatsConcurrency(e.ctx) if err != nil { return errors.Trace(err) diff --git a/executor/builder.go b/executor/builder.go index 81151e7efd03c..c7018e4f23d22 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -918,6 +918,9 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo v.OtherConditions, lhsTypes, rhsTypes) } metrics.ExecutorCounter.WithLabelValues("HashJoinExec").Inc() + if e.ctx.GetSessionVars().EnableRadixJoin { + return &RadixHashJoinExec{HashJoinExec: e} + } return e } diff --git a/executor/checksum.go b/executor/checksum.go index eb1bceaa06b40..3eff498d4f2b6 100644 --- a/executor/checksum.go +++ b/executor/checksum.go @@ -83,17 +83,17 @@ func (e *ChecksumTableExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *ChecksumTableExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *ChecksumTableExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() if e.done { return nil } for _, t := range e.tables { - chk.AppendString(0, t.DBInfo.Name.O) - chk.AppendString(1, t.TableInfo.Name.O) - chk.AppendUint64(2, t.Response.Checksum) - chk.AppendUint64(3, t.Response.TotalKvs) - chk.AppendUint64(4, t.Response.TotalBytes) + req.AppendString(0, t.DBInfo.Name.O) + req.AppendString(1, t.TableInfo.Name.O) + req.AppendUint64(2, t.Response.Checksum) + req.AppendUint64(3, t.Response.TotalKvs) + req.AppendUint64(4, t.Response.TotalBytes) } e.done = true return nil diff --git a/executor/ddl.go b/executor/ddl.go index 05c4f18eba002..c06f9caaadb2c 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -67,7 +67,7 @@ func (e *DDLExec) toErr(err error) error { } // Next implements the Executor Next interface. -func (e *DDLExec) Next(ctx context.Context, chk *chunk.Chunk) (err error) { +func (e *DDLExec) Next(ctx context.Context, req *chunk.RecordBatch) (err error) { if e.done { return nil } diff --git a/executor/ddl_test.go b/executor/ddl_test.go index cf88c198a5584..808fdad05fb61 100644 --- a/executor/ddl_test.go +++ b/executor/ddl_test.go @@ -25,7 +25,9 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/ddl" + ddlutil "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/meta/autoid" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" @@ -86,12 +88,12 @@ func (s *testSuite3) TestCreateTable(c *C) { rs, err := tk.Exec(`desc issue312_1`) c.Assert(err, IsNil) ctx := context.Background() - chk := rs.NewChunk() - it := chunk.NewIterator4Chunk(chk) + req := rs.NewRecordBatch() + it := chunk.NewIterator4Chunk(req.Chunk) for { - err1 := rs.Next(ctx, chk) + err1 := rs.Next(ctx, req) c.Assert(err1, IsNil) - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } for row := it.Begin(); row != it.End(); row = it.Next() { @@ -100,16 +102,16 @@ func (s *testSuite3) TestCreateTable(c *C) { } rs, err = tk.Exec(`desc issue312_2`) c.Assert(err, IsNil) - chk = rs.NewChunk() - it = chunk.NewIterator4Chunk(chk) + req = rs.NewRecordBatch() + it = chunk.NewIterator4Chunk(req.Chunk) for { - err1 := rs.Next(ctx, chk) + err1 := rs.Next(ctx, req) c.Assert(err1, IsNil) - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } for row := it.Begin(); row != it.End(); row = it.Next() { - c.Assert(chk.GetRow(0).GetString(1), Equals, "double") + c.Assert(req.GetRow(0).GetString(1), Equals, "double") } } @@ -243,10 +245,10 @@ func (s *testSuite3) TestAlterTableAddColumn(c *C) { now := time.Now().Add(-time.Duration(1 * time.Millisecond)).Format(types.TimeFormat) r, err := tk.Exec("select c2 from alter_test") c.Assert(err, IsNil) - chk := r.NewChunk() - err = r.Next(context.Background(), chk) + req := r.NewRecordBatch() + err = r.Next(context.Background(), req) c.Assert(err, IsNil) - row := chk.GetRow(0) + row := req.GetRow(0) c.Assert(row.Len(), Equals, 1) c.Assert(now, GreaterEqual, row.GetTime(0).String()) r.Close() @@ -469,6 +471,24 @@ func (s *testSuite3) TestShardRowIDBits(c *C) { _, err = tk.Exec("alter table auto shard_row_id_bits = 4") c.Assert(err, NotNil) tk.MustExec("alter table auto shard_row_id_bits = 0") + + // Test overflow + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (a int) shard_row_id_bits = 15") + defer tk.MustExec("drop table if exists t1") + + tbl, err = domain.GetDomain(tk.Se).InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t1")) + c.Assert(err, IsNil) + maxID := 1<<(64-15-1) - 1 + err = tbl.RebaseAutoID(tk.Se, int64(maxID)-1, false) + c.Assert(err, IsNil) + tk.MustExec("insert into t1 values(1)") + + // continue inserting will fail. + _, err = tk.Exec("insert into t1 values(2)") + c.Assert(autoid.ErrAutoincReadFailed.Equal(err), IsTrue, Commentf("err:%v", err)) + _, err = tk.Exec("insert into t1 values(3)") + c.Assert(autoid.ErrAutoincReadFailed.Equal(err), IsTrue, Commentf("err:%v", err)) } func (s *testSuite3) TestMaxHandleAddIndex(c *C) { @@ -491,24 +511,32 @@ func (s *testSuite3) TestMaxHandleAddIndex(c *C) { func (s *testSuite3) TestSetDDLReorgWorkerCnt(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") + err := ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(variable.DefTiDBDDLReorgWorkerCount)) - tk.MustExec("set tidb_ddl_reorg_worker_cnt = 1") + tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 1") + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(1)) - tk.MustExec("set tidb_ddl_reorg_worker_cnt = 100") + tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 100") + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(100)) - _, err := tk.Exec("set tidb_ddl_reorg_worker_cnt = invalid_val") + _, err = tk.Exec("set @@global.tidb_ddl_reorg_worker_cnt = invalid_val") c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue, Commentf("err %v", err)) - tk.MustExec("set tidb_ddl_reorg_worker_cnt = 100") + tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 100") + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(100)) - _, err = tk.Exec("set tidb_ddl_reorg_worker_cnt = -1") + _, err = tk.Exec("set @@global.tidb_ddl_reorg_worker_cnt = -1") c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue, Commentf("err %v", err)) - tk.MustExec("set tidb_ddl_reorg_worker_cnt = 100") - res := tk.MustQuery("select @@tidb_ddl_reorg_worker_cnt") + tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 100") + res := tk.MustQuery("select @@global.tidb_ddl_reorg_worker_cnt") res.Check(testkit.Rows("100")) res = tk.MustQuery("select @@global.tidb_ddl_reorg_worker_cnt") - res.Check(testkit.Rows(fmt.Sprintf("%v", variable.DefTiDBDDLReorgWorkerCount))) + res.Check(testkit.Rows("100")) tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 100") res = tk.MustQuery("select @@global.tidb_ddl_reorg_worker_cnt") res.Check(testkit.Rows("100")) @@ -517,28 +545,39 @@ func (s *testSuite3) TestSetDDLReorgWorkerCnt(c *C) { func (s *testSuite3) TestSetDDLReorgBatchSize(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") + err := ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(variable.DefTiDBDDLReorgBatchSize)) - tk.MustExec("set tidb_ddl_reorg_batch_size = 1") + tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 1") tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect tidb_ddl_reorg_batch_size value: '1'")) + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(variable.MinDDLReorgBatchSize)) - tk.MustExec(fmt.Sprintf("set tidb_ddl_reorg_batch_size = %v", variable.MaxDDLReorgBatchSize+1)) + tk.MustExec(fmt.Sprintf("set @@global.tidb_ddl_reorg_batch_size = %v", variable.MaxDDLReorgBatchSize+1)) tk.MustQuery("show warnings;").Check(testkit.Rows(fmt.Sprintf("Warning 1292 Truncated incorrect tidb_ddl_reorg_batch_size value: '%d'", variable.MaxDDLReorgBatchSize+1))) + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(variable.MaxDDLReorgBatchSize)) - _, err := tk.Exec("set tidb_ddl_reorg_batch_size = invalid_val") + _, err = tk.Exec("set @@global.tidb_ddl_reorg_batch_size = invalid_val") c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue, Commentf("err %v", err)) - tk.MustExec("set tidb_ddl_reorg_batch_size = 100") + tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 100") + err = ddlutil.LoadDDLReorgVars(tk.Se) + c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(100)) - tk.MustExec("set tidb_ddl_reorg_batch_size = -1") + tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = -1") tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect tidb_ddl_reorg_batch_size value: '-1'")) - tk.MustExec("set tidb_ddl_reorg_batch_size = 100") - res := tk.MustQuery("select @@tidb_ddl_reorg_batch_size") + tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 100") + res := tk.MustQuery("select @@global.tidb_ddl_reorg_batch_size") res.Check(testkit.Rows("100")) res = tk.MustQuery("select @@global.tidb_ddl_reorg_batch_size") - res.Check(testkit.Rows(fmt.Sprintf("%v", variable.DefTiDBDDLReorgBatchSize))) + res.Check(testkit.Rows(fmt.Sprintf("%v", 100))) tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 1000") res = tk.MustQuery("select @@global.tidb_ddl_reorg_batch_size") res.Check(testkit.Rows("1000")) + + // If do not LoadDDLReorgVars, the local variable will be the last loaded value. + c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(100)) } diff --git a/executor/delete.go b/executor/delete.go index f0fdc3c84c7cf..6cb537cbccddc 100644 --- a/executor/delete.go +++ b/executor/delete.go @@ -44,13 +44,13 @@ type DeleteExec struct { } // Next implements the Executor Next interface. -func (e *DeleteExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *DeleteExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("delete.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } - chk.Reset() + req.Reset() if e.IsMultiTable { return errors.Trace(e.deleteMultiTablesByChunk(ctx)) } @@ -106,7 +106,7 @@ func (e *DeleteExec) deleteSingleTableByChunk(ctx context.Context) error { for { iter := chunk.NewIterator4Chunk(chk) - err := e.children[0].Next(ctx, chk) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk)) if err != nil { return errors.Trace(err) } @@ -188,7 +188,7 @@ func (e *DeleteExec) deleteMultiTablesByChunk(ctx context.Context) error { chk := e.children[0].newFirstChunk() for { iter := chunk.NewIterator4Chunk(chk) - err := e.children[0].Next(ctx, chk) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk)) if err != nil { return errors.Trace(err) } diff --git a/executor/distsql.go b/executor/distsql.go index 8b2b6a21820e6..9a8948f9a8643 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -123,7 +123,7 @@ func statementContextToFlags(sc *stmtctx.StatementContext) uint64 { var flags uint64 if sc.InInsertStmt { flags |= model.FlagInInsertStmt - } else if sc.InUpdateOrDeleteStmt { + } else if sc.InUpdateStmt || sc.InDeleteStmt { flags |= model.FlagInUpdateOrDeleteStmt } else if sc.InSelectStmt { flags |= model.FlagInSelectStmt @@ -248,16 +248,16 @@ func (e *IndexReaderExecutor) Close() error { } // Next implements the Executor Next interface. -func (e *IndexReaderExecutor) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *IndexReaderExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("tableReader.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - err := e.result.Next(ctx, chk) + err := e.result.Next(ctx, req.Chunk) if err != nil { e.feedback.Invalidate() } @@ -539,12 +539,12 @@ func (e *IndexLookUpExecutor) Close() error { } // Next implements Exec Next interface. -func (e *IndexLookUpExecutor) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *IndexLookUpExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() + req.Reset() for { resultTask, err := e.getResultTask() if err != nil { @@ -554,9 +554,9 @@ func (e *IndexLookUpExecutor) Next(ctx context.Context, chk *chunk.Chunk) error return nil } for resultTask.cursor < len(resultTask.rows) { - chk.AppendRow(resultTask.rows[resultTask.cursor]) + req.AppendRow(resultTask.rows[resultTask.cursor]) resultTask.cursor++ - if chk.NumRows() >= e.maxChunkSize { + if req.NumRows() >= e.maxChunkSize { return nil } } @@ -745,7 +745,7 @@ func (w *tableWorker) executeTask(ctx context.Context, task *lookupTableTask) er task.rows = make([]chunk.Row, 0, handleCnt) for { chk := tableReader.newFirstChunk() - err = tableReader.Next(ctx, chk) + err = tableReader.Next(ctx, chunk.NewRecordBatch(chk)) if err != nil { log.Error(err) return errors.Trace(err) diff --git a/executor/distsql_test.go b/executor/distsql_test.go index 54a50b22f214b..d63624e883343 100644 --- a/executor/distsql_test.go +++ b/executor/distsql_test.go @@ -68,10 +68,10 @@ func (s *testSuite3) TestCopClientSend(c *C) { // Send coprocessor request when the table split. rs, err := tk.Exec("select sum(id) from copclient") c.Assert(err, IsNil) - chk := rs.NewChunk() - err = rs.Next(ctx, chk) + req := rs.NewRecordBatch() + err = rs.Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.GetRow(0).GetMyDecimal(0).String(), Equals, "499500") + c.Assert(req.GetRow(0).GetMyDecimal(0).String(), Equals, "499500") rs.Close() // Split one region. @@ -83,17 +83,17 @@ func (s *testSuite3) TestCopClientSend(c *C) { // Check again. rs, err = tk.Exec("select sum(id) from copclient") c.Assert(err, IsNil) - chk = rs.NewChunk() - err = rs.Next(ctx, chk) + req = rs.NewRecordBatch() + err = rs.Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.GetRow(0).GetMyDecimal(0).String(), Equals, "499500") + c.Assert(req.GetRow(0).GetMyDecimal(0).String(), Equals, "499500") rs.Close() // Check there is no goroutine leak. rs, err = tk.Exec("select * from copclient order by id") c.Assert(err, IsNil) - chk = rs.NewChunk() - err = rs.Next(ctx, chk) + req = rs.NewRecordBatch() + err = rs.Next(ctx, req) c.Assert(err, IsNil) rs.Close() keyword := "(*copIterator).work" diff --git a/executor/executor.go b/executor/executor.go index e5b7697261547..4dd821d9b442e 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -124,7 +124,7 @@ func (e *baseExecutor) retTypes() []*types.FieldType { } // Next fills mutiple rows into a chunk. -func (e *baseExecutor) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *baseExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { return nil } @@ -162,7 +162,7 @@ func newBaseExecutor(ctx sessionctx.Context, schema *expression.Schema, id strin // NOTE: Executors must call "chk.Reset()" before appending their results to it. type Executor interface { Open(context.Context) error - Next(ctx context.Context, chk *chunk.Chunk) error + Next(ctx context.Context, req *chunk.RecordBatch) error Close() error Schema() *expression.Schema @@ -180,22 +180,22 @@ type CancelDDLJobsExec struct { } // Next implements the Executor Next interface. -func (e *CancelDDLJobsExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *CancelDDLJobsExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.GrowAndReset(e.maxChunkSize) + req.GrowAndReset(e.maxChunkSize) if e.cursor >= len(e.jobIDs) { return nil } - numCurBatch := mathutil.Min(chk.Capacity(), len(e.jobIDs)-e.cursor) + numCurBatch := mathutil.Min(req.Capacity(), len(e.jobIDs)-e.cursor) for i := e.cursor; i < e.cursor+numCurBatch; i++ { - chk.AppendString(0, fmt.Sprintf("%d", e.jobIDs[i])) + req.AppendString(0, fmt.Sprintf("%d", e.jobIDs[i])) if e.errs[i] != nil { - chk.AppendString(1, fmt.Sprintf("error: %v", e.errs[i])) + req.AppendString(1, fmt.Sprintf("error: %v", e.errs[i])) } else { - chk.AppendString(1, "successful") + req.AppendString(1, "successful") } } e.cursor += numCurBatch @@ -210,8 +210,8 @@ type ShowNextRowIDExec struct { } // Next implements the Executor Next interface. -func (e *ShowNextRowIDExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *ShowNextRowIDExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() if e.done { return nil } @@ -231,10 +231,10 @@ func (e *ShowNextRowIDExec) Next(ctx context.Context, chk *chunk.Chunk) error { if err != nil { return errors.Trace(err) } - chk.AppendString(0, e.tblName.Schema.O) - chk.AppendString(1, e.tblName.Name.O) - chk.AppendString(2, colName.O) - chk.AppendInt64(3, nextGlobalID) + req.AppendString(0, e.tblName.Schema.O) + req.AppendString(1, e.tblName.Name.O) + req.AppendString(2, colName.O) + req.AppendInt64(3, nextGlobalID) e.done = true return nil } @@ -250,8 +250,8 @@ type ShowDDLExec struct { } // Next implements the Executor Next interface. -func (e *ShowDDLExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *ShowDDLExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() if e.done { return nil } @@ -264,10 +264,10 @@ func (e *ShowDDLExec) Next(ctx context.Context, chk *chunk.Chunk) error { ddlJobs += "\n" } } - chk.AppendInt64(0, e.ddlInfo.SchemaVer) - chk.AppendString(1, e.ddlOwnerID) - chk.AppendString(2, ddlJobs) - chk.AppendString(3, e.selfID) + req.AppendInt64(0, e.ddlInfo.SchemaVer) + req.AppendString(1, e.ddlOwnerID) + req.AppendString(2, ddlJobs) + req.AppendString(3, e.selfID) e.done = true return nil } @@ -318,19 +318,19 @@ func (e *ShowDDLJobQueriesExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *ShowDDLJobQueriesExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.GrowAndReset(e.maxChunkSize) +func (e *ShowDDLJobQueriesExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.GrowAndReset(e.maxChunkSize) if e.cursor >= len(e.jobs) { return nil } if len(e.jobIDs) >= len(e.jobs) { return nil } - numCurBatch := mathutil.Min(chk.Capacity(), len(e.jobs)-e.cursor) + numCurBatch := mathutil.Min(req.Capacity(), len(e.jobs)-e.cursor) for _, id := range e.jobIDs { for i := e.cursor; i < e.cursor+numCurBatch; i++ { if id == e.jobs[i].ID { - chk.AppendString(0, e.jobs[i].Query) + req.AppendString(0, e.jobs[i].Query) } } } @@ -365,23 +365,23 @@ func (e *ShowDDLJobsExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *ShowDDLJobsExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.GrowAndReset(e.maxChunkSize) +func (e *ShowDDLJobsExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.GrowAndReset(e.maxChunkSize) if e.cursor >= len(e.jobs) { return nil } - numCurBatch := mathutil.Min(chk.Capacity(), len(e.jobs)-e.cursor) + numCurBatch := mathutil.Min(req.Capacity(), len(e.jobs)-e.cursor) for i := e.cursor; i < e.cursor+numCurBatch; i++ { - chk.AppendInt64(0, e.jobs[i].ID) - chk.AppendString(1, getSchemaName(e.is, e.jobs[i].SchemaID)) - chk.AppendString(2, getTableName(e.is, e.jobs[i].TableID)) - chk.AppendString(3, e.jobs[i].Type.String()) - chk.AppendString(4, e.jobs[i].SchemaState.String()) - chk.AppendInt64(5, e.jobs[i].SchemaID) - chk.AppendInt64(6, e.jobs[i].TableID) - chk.AppendInt64(7, e.jobs[i].RowCount) - chk.AppendString(8, model.TSConvert2Time(e.jobs[i].StartTS).String()) - chk.AppendString(9, e.jobs[i].State.String()) + req.AppendInt64(0, e.jobs[i].ID) + req.AppendString(1, getSchemaName(e.is, e.jobs[i].SchemaID)) + req.AppendString(2, getTableName(e.is, e.jobs[i].TableID)) + req.AppendString(3, e.jobs[i].Type.String()) + req.AppendString(4, e.jobs[i].SchemaState.String()) + req.AppendInt64(5, e.jobs[i].SchemaID) + req.AppendInt64(6, e.jobs[i].TableID) + req.AppendInt64(7, e.jobs[i].RowCount) + req.AppendString(8, model.TSConvert2Time(e.jobs[i].StartTS).String()) + req.AppendString(9, e.jobs[i].State.String()) } e.cursor += numCurBatch return nil @@ -432,7 +432,7 @@ func (e *CheckTableExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *CheckTableExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *CheckTableExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.done { return nil } @@ -521,7 +521,7 @@ func (e *CheckIndexExec) Close() error { } // Next implements the Executor Next interface. -func (e *CheckIndexExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *CheckIndexExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.done { return nil } @@ -531,9 +531,9 @@ func (e *CheckIndexExec) Next(ctx context.Context, chk *chunk.Chunk) error { if err != nil { return errors.Trace(err) } - chk = e.src.newFirstChunk() + chk := e.src.newFirstChunk() for { - err := e.src.Next(ctx, chk) + err := e.src.Next(ctx, chunk.NewRecordBatch(chk)) if err != nil { return errors.Trace(err) } @@ -568,37 +568,37 @@ func (e *ShowSlowExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *ShowSlowExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *ShowSlowExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() if e.cursor >= len(e.result) { return nil } - for e.cursor < len(e.result) && chk.NumRows() < e.maxChunkSize { + for e.cursor < len(e.result) && req.NumRows() < e.maxChunkSize { slow := e.result[e.cursor] - chk.AppendString(0, slow.SQL) - chk.AppendTime(1, types.Time{ + req.AppendString(0, slow.SQL) + req.AppendTime(1, types.Time{ Time: types.FromGoTime(slow.Start), Type: mysql.TypeTimestamp, Fsp: types.MaxFsp, }) - chk.AppendDuration(2, types.Duration{Duration: slow.Duration, Fsp: types.MaxFsp}) - chk.AppendString(3, slow.Detail.String()) + req.AppendDuration(2, types.Duration{Duration: slow.Duration, Fsp: types.MaxFsp}) + req.AppendString(3, slow.Detail.String()) if slow.Succ { - chk.AppendInt64(4, 1) + req.AppendInt64(4, 1) } else { - chk.AppendInt64(4, 0) + req.AppendInt64(4, 0) } - chk.AppendUint64(5, slow.ConnID) - chk.AppendUint64(6, slow.TxnTS) - chk.AppendString(7, slow.User) - chk.AppendString(8, slow.DB) - chk.AppendString(9, slow.TableIDs) - chk.AppendString(10, slow.IndexIDs) + req.AppendUint64(5, slow.ConnID) + req.AppendUint64(6, slow.TxnTS) + req.AppendString(7, slow.User) + req.AppendString(8, slow.DB) + req.AppendString(9, slow.TableIDs) + req.AppendString(10, slow.IndexIDs) if slow.Internal { - chk.AppendInt64(11, 0) + req.AppendInt64(11, 0) } else { - chk.AppendInt64(11, 1) + req.AppendInt64(11, 1) } e.cursor++ } @@ -633,9 +633,9 @@ func (e *SelectLockExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *SelectLockExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.GrowAndReset(e.maxChunkSize) - err := e.children[0].Next(ctx, chk) +func (e *SelectLockExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.GrowAndReset(e.maxChunkSize) + err := e.children[0].Next(ctx, req) if err != nil { return errors.Trace(err) } @@ -647,8 +647,8 @@ func (e *SelectLockExec) Next(ctx context.Context, chk *chunk.Chunk) error { if err != nil { return errors.Trace(err) } - keys := make([]kv.Key, 0, chk.NumRows()) - iter := chunk.NewIterator4Chunk(chk) + keys := make([]kv.Key, 0, req.NumRows()) + iter := chunk.NewIterator4Chunk(req.Chunk) for id, cols := range e.Schema().TblID2Handle { for _, col := range cols { keys = keys[:0] @@ -680,21 +680,21 @@ type LimitExec struct { } // Next implements the Executor Next interface. -func (e *LimitExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *LimitExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("limit.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() + req.Reset() if e.cursor >= e.end { return nil } for !e.meetFirstBatch { - err := e.children[0].Next(ctx, e.childResult) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) if err != nil { return errors.Trace(err) } @@ -713,22 +713,22 @@ func (e *LimitExec) Next(ctx context.Context, chk *chunk.Chunk) error { if begin == end { break } - chk.Append(e.childResult, int(begin), int(end)) + req.Append(e.childResult, int(begin), int(end)) return nil } e.cursor += batchSize } - err := e.children[0].Next(ctx, chk) + err := e.children[0].Next(ctx, req) if err != nil { return errors.Trace(err) } - batchSize := uint64(chk.NumRows()) + batchSize := uint64(req.NumRows()) // no more data. if batchSize == 0 { return nil } if e.cursor+batchSize > e.end { - chk.TruncateTo(int(e.end - e.cursor)) + req.TruncateTo(int(e.end - e.cursor)) batchSize = e.end - e.cursor } e.cursor += batchSize @@ -770,7 +770,7 @@ func init() { } chk := exec.newFirstChunk() for { - err = exec.Next(ctx, chk) + err = exec.Next(ctx, chunk.NewRecordBatch(chk)) if err != nil { return rows, errors.Trace(err) } @@ -803,24 +803,24 @@ func (e *TableDualExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *TableDualExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *TableDualExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("tableDual.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() + req.Reset() if e.numReturned >= e.numDualRows { return nil } if e.Schema().Len() == 0 { - chk.SetNumVirtualRows(1) + req.SetNumVirtualRows(1) } else { for i := range e.Schema().Columns { - chk.AppendNull(i) + req.AppendNull(i) } } e.numReturned = e.numDualRows @@ -862,19 +862,19 @@ func (e *SelectionExec) Close() error { } // Next implements the Executor Next interface. -func (e *SelectionExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *SelectionExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("selection.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.GrowAndReset(e.maxChunkSize) + req.GrowAndReset(e.maxChunkSize) if !e.batched { - return errors.Trace(e.unBatchedNext(ctx, chk)) + return errors.Trace(e.unBatchedNext(ctx, req.Chunk)) } for { @@ -882,12 +882,12 @@ func (e *SelectionExec) Next(ctx context.Context, chk *chunk.Chunk) error { if !e.selected[e.inputRow.Idx()] { continue } - if chk.NumRows() >= chk.Capacity() { + if req.NumRows() >= req.Capacity() { return nil } - chk.AppendRow(e.inputRow) + req.AppendRow(e.inputRow) } - err := e.children[0].Next(ctx, e.childResult) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) if err != nil { return errors.Trace(err) } @@ -919,7 +919,7 @@ func (e *SelectionExec) unBatchedNext(ctx context.Context, chk *chunk.Chunk) err return nil } } - err := e.children[0].Next(ctx, e.childResult) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) if err != nil { return errors.Trace(err) } @@ -945,18 +945,18 @@ type TableScanExec struct { } // Next implements the Executor Next interface. -func (e *TableScanExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *TableScanExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("tableScan.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.GrowAndReset(e.maxChunkSize) + req.GrowAndReset(e.maxChunkSize) if e.isVirtualTable { - return errors.Trace(e.nextChunk4InfoSchema(ctx, chk)) + return errors.Trace(e.nextChunk4InfoSchema(ctx, req.Chunk)) } handle, found, err := e.nextHandle() if err != nil || !found { @@ -964,14 +964,14 @@ func (e *TableScanExec) Next(ctx context.Context, chk *chunk.Chunk) error { } mutableRow := chunk.MutRowFromTypes(e.retTypes()) - for chk.NumRows() < chk.Capacity() { + for req.NumRows() < req.Capacity() { row, err := e.getRow(handle) if err != nil { return errors.Trace(err) } e.seekHandle = handle + 1 mutableRow.SetDatums(row...) - chk.AppendRow(mutableRow.ToRow()) + req.AppendRow(mutableRow.ToRow()) } return nil } @@ -1053,28 +1053,28 @@ func (e *MaxOneRowExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *MaxOneRowExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("maxOneRow.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() + req.Reset() if e.evaluated { return nil } e.evaluated = true - err := e.children[0].Next(ctx, chk) + err := e.children[0].Next(ctx, req) if err != nil { return errors.Trace(err) } - if num := chk.NumRows(); num == 0 { + if num := req.NumRows(); num == 0 { for i := range e.schema.Columns { - chk.AppendNull(i) + req.AppendNull(i) } return nil } else if num != 1 { @@ -1082,7 +1082,7 @@ func (e *MaxOneRowExec) Next(ctx context.Context, chk *chunk.Chunk) error { } childChunk := e.children[0].newFirstChunk() - err = e.children[0].Next(ctx, childChunk) + err = e.children[0].Next(ctx, chunk.NewRecordBatch(childChunk)) if err != nil { return errors.Trace(err) } @@ -1193,7 +1193,7 @@ func (e *UnionExec) resultPuller(ctx context.Context, childID int) { return case result.chk = <-e.resourcePools[childID]: } - result.err = errors.Trace(e.children[childID].Next(ctx, result.chk)) + result.err = errors.Trace(e.children[childID].Next(ctx, chunk.NewRecordBatch(result.chk))) if result.err == nil && result.chk.NumRows() == 0 { return } @@ -1206,16 +1206,16 @@ func (e *UnionExec) resultPuller(ctx context.Context, childID int) { } // Next implements the Executor Next interface. -func (e *UnionExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *UnionExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("union.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.GrowAndReset(e.maxChunkSize) + req.GrowAndReset(e.maxChunkSize) if !e.initialized { e.initialize(ctx) e.initialized = true @@ -1228,7 +1228,7 @@ func (e *UnionExec) Next(ctx context.Context, chk *chunk.Chunk) error { return errors.Trace(result.err) } - chk.SwapColumns(result.chk) + req.SwapColumns(result.chk) result.src <- result.chk return nil } @@ -1272,7 +1272,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { // pushing them down to TiKV as flags. switch stmt := s.(type) { case *ast.UpdateStmt: - sc.InUpdateOrDeleteStmt = true + sc.InUpdateStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr @@ -1280,7 +1280,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.IgnoreZeroInDate = !vars.StrictSQLMode || stmt.IgnoreErr sc.Priority = stmt.Priority case *ast.DeleteStmt: - sc.InUpdateOrDeleteStmt = true + sc.InDeleteStmt = true sc.DupKeyAsWarning = stmt.IgnoreErr sc.BadNullAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr sc.TruncateAsWarning = !vars.StrictSQLMode || stmt.IgnoreErr @@ -1301,6 +1301,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.DupKeyAsWarning = true sc.BadNullAsWarning = true sc.TruncateAsWarning = !vars.StrictSQLMode + sc.InLoadDataStmt = true case *ast.SelectStmt: sc.InSelectStmt = true @@ -1341,7 +1342,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID } sc.PrevAffectedRows = 0 - if vars.StmtCtx.InUpdateOrDeleteStmt || vars.StmtCtx.InInsertStmt { + if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt { sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows()) } else if vars.StmtCtx.InSelectStmt { sc.PrevAffectedRows = -1 diff --git a/executor/executor_pkg_test.go b/executor/executor_pkg_test.go index 635c3b6bf136a..b454b6c8ed4b2 100644 --- a/executor/executor_pkg_test.go +++ b/executor/executor_pkg_test.go @@ -95,13 +95,13 @@ func (s *testExecSuite) TestShowProcessList(c *C) { it := chunk.NewIterator4Chunk(chk) // Run test and check results. for _, p := range ps { - err = e.Next(context.Background(), chk) + err = e.Next(context.Background(), chunk.NewRecordBatch(chk)) c.Assert(err, IsNil) for row := it.Begin(); row != it.End(); row = it.Next() { c.Assert(row.GetUint64(0), Equals, p.ID) } } - err = e.Next(context.Background(), chk) + err = e.Next(context.Background(), chunk.NewRecordBatch(chk)) c.Assert(err, IsNil) c.Assert(chk.NumRows(), Equals, 0) err = e.Close() diff --git a/executor/executor_test.go b/executor/executor_test.go index 5e5efbb4ce356..fc00607c1fb71 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -149,20 +149,20 @@ func (s *testSuite) TestAdmin(c *C) { // cancel DDL jobs test r, err := tk.Exec("admin cancel ddl jobs 1") c.Assert(err, IsNil, Commentf("err %v", err)) - chk := r.NewChunk() - err = r.Next(ctx, chk) + req := r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - row := chk.GetRow(0) + row := req.GetRow(0) c.Assert(row.Len(), Equals, 2) c.Assert(row.GetString(0), Equals, "1") c.Assert(row.GetString(1), Equals, "error: [admin:4]DDL Job:1 not found") r, err = tk.Exec("admin show ddl") c.Assert(err, IsNil) - chk = r.NewChunk() - err = r.Next(ctx, chk) + req = r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - row = chk.GetRow(0) + row = req.GetRow(0) c.Assert(row.Len(), Equals, 4) txn, err := s.store.Begin() c.Assert(err, IsNil) @@ -174,20 +174,20 @@ func (s *testSuite) TestAdmin(c *C) { // ownerInfos := strings.Split(ddlInfo.Owner.String(), ",") // c.Assert(rowOwnerInfos[0], Equals, ownerInfos[0]) c.Assert(row.GetString(2), Equals, "") - chk = r.NewChunk() - err = r.Next(ctx, chk) + req = r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsTrue) + c.Assert(req.NumRows() == 0, IsTrue) err = txn.Rollback() c.Assert(err, IsNil) // show DDL jobs test r, err = tk.Exec("admin show ddl jobs") c.Assert(err, IsNil) - chk = r.NewChunk() - err = r.Next(ctx, chk) + req = r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - row = chk.GetRow(0) + row = req.GetRow(0) c.Assert(row.Len(), Equals, 10) txn, err = s.store.Begin() c.Assert(err, IsNil) @@ -200,10 +200,10 @@ func (s *testSuite) TestAdmin(c *C) { r, err = tk.Exec("admin show ddl jobs 20") c.Assert(err, IsNil) - chk = r.NewChunk() - err = r.Next(ctx, chk) + req = r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - row = chk.GetRow(0) + row = req.GetRow(0) c.Assert(row.Len(), Equals, 10) c.Assert(row.GetInt64(0), Equals, historyJobs[0].ID) c.Assert(err, IsNil) @@ -342,6 +342,7 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo, c.Assert(ctx.NewTxn(context.Background()), IsNil) ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true + ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true data, reachLimit, err1 := ld.InsertData(tt.data1, tt.data2) c.Assert(err1, IsNil) c.Assert(reachLimit, IsFalse) @@ -849,10 +850,10 @@ func (s *testSuite) TestIssue2612(c *C) { tk.MustExec(`insert into t values ('2016-02-13 15:32:24', '2016-02-11 17:23:22');`) rs, err := tk.Exec(`select timediff(finish_at, create_at) from t;`) c.Assert(err, IsNil) - chk := rs.NewChunk() - err = rs.Next(context.Background(), chk) + req := rs.NewRecordBatch() + err = rs.Next(context.Background(), req) c.Assert(err, IsNil) - c.Assert(chk.GetRow(0).GetDuration(0, 0).String(), Equals, "-46:09:02") + c.Assert(req.GetRow(0).GetDuration(0, 0).String(), Equals, "-46:09:02") rs.Close() } @@ -2608,10 +2609,10 @@ func (s *testSuite) TestBit(c *C) { c.Assert(err, NotNil) r, err := tk.Exec("select * from t where c1 = 2") c.Assert(err, IsNil) - chk := r.NewChunk() - err = r.Next(context.Background(), chk) + req := r.NewRecordBatch() + err = r.Next(context.Background(), req) c.Assert(err, IsNil) - c.Assert(types.BinaryLiteral(chk.GetRow(0).GetBytes(0)), DeepEquals, types.NewBinaryLiteralFromUint(2, -1)) + c.Assert(types.BinaryLiteral(req.GetRow(0).GetBytes(0)), DeepEquals, types.NewBinaryLiteralFromUint(2, -1)) r.Close() tk.MustExec("drop table if exists t") @@ -3237,7 +3238,7 @@ func (s *testSuite3) TestMaxOneRow(c *C) { rs, err := tk.Exec(`select (select t1.a from t1 where t1.a > t2.a) as a from t2;`) c.Assert(err, IsNil) - err = rs.Next(context.TODO(), rs.NewChunk()) + err = rs.Next(context.TODO(), rs.NewRecordBatch()) c.Assert(err.Error(), Equals, "subquery returns more than 1 row") err = rs.Close() @@ -3493,3 +3494,43 @@ func (s *testSuite3) TearDownTest(c *C) { tk.MustExec(fmt.Sprintf("drop table %v", tableName)) } } + +func (s *testSuite) TestStrToDateBuiltin(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustQuery(`select str_to_date('18/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('a18/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('69/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2069-10-22")) + tk.MustQuery(`select str_to_date('70/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("1970-10-22")) + tk.MustQuery(`select str_to_date('8/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2008-10-22")) + tk.MustQuery(`select str_to_date('8/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2008-10-22")) + tk.MustQuery(`select str_to_date('18/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('a18/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('69/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2069-10-22")) + tk.MustQuery(`select str_to_date('70/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("1970-10-22")) + tk.MustQuery(`select str_to_date('018/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("0018-10-22")) + tk.MustQuery(`select str_to_date('2018/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('018/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18/10/22','%y0/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18/10/22','%Y0/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18a/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18a/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('20188/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('2018510522','%Y5%m5%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018^10^22','%Y^%m^%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018@10@22','%Y@%m@%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018%10%22','%Y%%m%%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('2018(10(22','%Y(%m(%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018\10\22','%Y\%m\%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('2018=10=22','%Y=%m=%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018+10+22','%Y+%m+%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('2018_10_22','%Y_%m_%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('69510522','%y5%m5%d') from dual`).Check(testkit.Rows("2069-10-22")) + tk.MustQuery(`select str_to_date('69^10^22','%y^%m^%d') from dual`).Check(testkit.Rows("2069-10-22")) + tk.MustQuery(`select str_to_date('18@10@22','%y@%m@%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('18%10%22','%y%%m%%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18(10(22','%y(%m(%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('18\10\22','%y\%m\%d') from dual`).Check(testkit.Rows("")) + tk.MustQuery(`select str_to_date('18+10+22','%y+%m+%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('18=10=22','%y=%m=%d') from dual`).Check(testkit.Rows("2018-10-22")) + tk.MustQuery(`select str_to_date('18_10_22','%y_%m_%d') from dual`).Check(testkit.Rows("2018-10-22")) +} diff --git a/executor/explain.go b/executor/explain.go index 5e14443df664e..61ced6d564b62 100644 --- a/executor/explain.go +++ b/executor/explain.go @@ -46,7 +46,7 @@ func (e *ExplainExec) Close() error { } // Next implements the Executor Next interface. -func (e *ExplainExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *ExplainExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.rows == nil { var err error e.rows, err = e.generateExplainInfo(ctx) @@ -55,15 +55,15 @@ func (e *ExplainExec) Next(ctx context.Context, chk *chunk.Chunk) error { } } - chk.GrowAndReset(e.maxChunkSize) + req.GrowAndReset(e.maxChunkSize) if e.cursor >= len(e.rows) { return nil } - numCurRows := mathutil.Min(chk.Capacity(), len(e.rows)-e.cursor) + numCurRows := mathutil.Min(req.Capacity(), len(e.rows)-e.cursor) for i := e.cursor; i < e.cursor+numCurRows; i++ { for j := range e.rows[i] { - chk.AppendString(j, e.rows[i][j]) + req.AppendString(j, e.rows[i][j]) } } e.cursor += numCurRows @@ -74,7 +74,7 @@ func (e *ExplainExec) generateExplainInfo(ctx context.Context) ([][]string, erro if e.analyzeExec != nil { chk := e.analyzeExec.newFirstChunk() for { - err := e.analyzeExec.Next(ctx, chk) + err := e.analyzeExec.Next(ctx, chunk.NewRecordBatch(chk)) if err != nil { return nil, err } diff --git a/executor/grant.go b/executor/grant.go index 4cfdae5f9f5c4..b98e68922a215 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -53,7 +53,7 @@ type GrantExec struct { } // Next implements the Executor Next interface. -func (e *GrantExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *GrantExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.done { return nil } diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go index d0e83c4f25f88..05fdb6b405d15 100644 --- a/executor/index_lookup_join.go +++ b/executor/index_lookup_join.go @@ -211,12 +211,12 @@ func (e *IndexLookUpJoin) newInnerWorker(taskCh chan *lookUpJoinTask) *innerWork } // Next implements the Executor interface. -func (e *IndexLookUpJoin) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *IndexLookUpJoin) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() + req.Reset() e.joinResult.Reset() for { task, err := e.getFinishedTask(ctx) @@ -234,7 +234,7 @@ func (e *IndexLookUpJoin) Next(ctx context.Context, chk *chunk.Chunk) error { outerRow := task.outerResult.GetRow(task.cursor) if e.innerIter.Current() != e.innerIter.End() { - matched, err := e.joiner.tryToMatch(outerRow, e.innerIter, chk) + matched, err := e.joiner.tryToMatch(outerRow, e.innerIter, req.Chunk) if err != nil { return errors.Trace(err) } @@ -242,12 +242,12 @@ func (e *IndexLookUpJoin) Next(ctx context.Context, chk *chunk.Chunk) error { } if e.innerIter.Current() == e.innerIter.End() { if !task.hasMatch { - e.joiner.onMissMatch(outerRow, chk) + e.joiner.onMissMatch(outerRow, req.Chunk) } task.cursor++ task.hasMatch = false } - if chk.NumRows() == e.maxChunkSize { + if req.NumRows() == e.maxChunkSize { return nil } } @@ -359,7 +359,7 @@ func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) { task.memTracker.Consume(task.outerResult.MemoryUsage()) for task.outerResult.NumRows() < ow.batchSize { - err := ow.executor.Next(ctx, ow.executorChk) + err := ow.executor.Next(ctx, chunk.NewRecordBatch(ow.executorChk)) if err != nil { return task, errors.Trace(err) } @@ -547,7 +547,7 @@ func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTa innerResult.GetMemTracker().SetLabel("inner result") innerResult.GetMemTracker().AttachTo(task.memTracker) for { - err := innerExec.Next(ctx, iw.executorChk) + err := innerExec.Next(ctx, chunk.NewRecordBatch(iw.executorChk)) if err != nil { return errors.Trace(err) } diff --git a/executor/index_lookup_join_test.go b/executor/index_lookup_join_test.go index d2f494f84c962..6ab45c7652962 100644 --- a/executor/index_lookup_join_test.go +++ b/executor/index_lookup_join_test.go @@ -31,9 +31,9 @@ func (s *testSuite1) TestIndexLookupJoinHang(c *C) { rs, err := tk.Exec("select /*+ TIDB_INLJ(i)*/ * from idxJoinOuter o left join idxJoinInner i on o.a = i.a where o.a in (1, 2) and (i.a - 3) > 0") c.Assert(err, IsNil) - chk := rs.NewChunk() + req := rs.NewRecordBatch() for i := 0; i < 5; i++ { - rs.Next(context.Background(), chk) + rs.Next(context.Background(), req) } rs.Close() } diff --git a/executor/insert.go b/executor/insert.go index e258db753e6a7..fac4723471f91 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -135,13 +135,13 @@ func (e *InsertExec) batchUpdateDupRows(newRows [][]types.Datum) error { } // Next implements the Executor Next interface. -func (e *InsertExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *InsertExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("insert.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } - chk.Reset() + req.Reset() if len(e.children) > 0 && e.children[0] != nil { return e.insertRowsFromSelect(ctx, e.exec) } diff --git a/executor/insert_common.go b/executor/insert_common.go index 03d666ced463c..a20268c7e4422 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -307,7 +307,7 @@ func (e *InsertValues) insertRowsFromSelect(ctx context.Context, exec func(ctx c batchSize := sessVars.DMLBatchSize for { - err := selectExec.Next(ctx, chk) + err := selectExec.Next(ctx, chunk.NewRecordBatch(chk)) if err != nil { return errors.Trace(err) } diff --git a/executor/join.go b/executor/join.go index a9bc7f24e0d6b..c0fd8199356d9 100644 --- a/executor/join.go +++ b/executor/join.go @@ -15,7 +15,6 @@ package executor import ( "context" - "math" "sync" "sync/atomic" "time" @@ -31,7 +30,6 @@ import ( "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/mvmap" - "github.com/spaolacci/murmur3" ) var ( @@ -52,7 +50,7 @@ type HashJoinExec struct { prepared bool // concurrency is the number of partition, build and join workers. concurrency uint - hashTable *mvmap.MVMap + globalHashTable *mvmap.MVMap innerFinished chan error hashJoinBuffers []*hashJoinBuffer // joinWorkerWaitGroup is for sync multiple join workers. @@ -77,37 +75,8 @@ type HashJoinExec struct { hashTableValBufs [][][]byte memTracker *memory.Tracker // track memory usage. - - // radixBits indicates the bits using for radix partitioning. Inner relation - // will be split to 2^radixBitsNumber sub-relations before building the hash - // tables. If the complete inner relation can be hold in L2Cache in which - // case radixBits will be 1, we can skip the partition phase. - // Note: We actually check whether `size of sub inner relation < 3/4 * L2 - // cache size` to make sure one inner sub-relation, hash table, one outer - // sub-relation and join result of the sub-relations can be totally loaded - // in L2 cache size. `3/4` is a magic number, we may adjust it after - // benchmark. - radixBits uint32 - innerParts []partition - // innerRowPrts indicates the position in corresponding partition of every - // row in innerResult. - innerRowPrts [][]partRowPtr } -// partition stores the sub-relations of inner relation and outer relation after -// partition phase. Every partition can be fully stored in L2 cache thus can -// reduce the cache miss ratio when building and probing the hash table. -type partition = *chunk.Chunk - -// partRowPtr stores the actual index in `innerParts` or `outerParts`. -type partRowPtr struct { - partitionIdx uint32 - rowIdx uint32 -} - -// partPtr4NullKey indicates a partition pointer which points to a row with null-join-key. -var partPtr4NullKey = partRowPtr{math.MaxUint32, math.MaxUint32} - // outerChkResource stores the result of the join outer fetch worker, // `dest` is for Chunk reuse: after join workers process the outer chunk which is read from `dest`, // they'll store the used chunk as `chk`, and then the outer fetch worker will put new data into `chk` and write `chk` into dest. @@ -188,6 +157,11 @@ func (e *HashJoinExec) Open(ctx context.Context) error { e.hashJoinBuffers = append(e.hashJoinBuffers, buffer) } + e.innerKeyColIdx = make([]int, len(e.innerKeys)) + for i := range e.innerKeys { + e.innerKeyColIdx[i] = e.innerKeys[i].Index + } + e.closeCh = make(chan struct{}) e.finished.Store(false) e.joinWorkerWaitGroup = sync.WaitGroup{} @@ -238,7 +212,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) { } } outerResult := outerResource.chk - err := e.outerExec.Next(ctx, outerResult) + err := e.outerExec.Next(ctx, chunk.NewRecordBatch(outerResult)) if err != nil { e.joinResultCh <- &hashjoinWorkerResult{ err: errors.Trace(err), @@ -278,146 +252,30 @@ func (e *HashJoinExec) wait4Inner() (finished bool, err error) { return false, errors.Trace(err) } } - if e.hashTable.Len() == 0 && (e.joinType == plannercore.InnerJoin || e.joinType == plannercore.SemiJoin) { + if e.innerResult.Len() == 0 && (e.joinType == plannercore.InnerJoin || e.joinType == plannercore.SemiJoin) { return true, nil } return false, nil } -// fetchInnerRows fetches all rows from inner executor, -// and append them to e.innerResult. -func (e *HashJoinExec) fetchInnerRows(ctx context.Context, chkCh chan<- *chunk.Chunk, doneCh chan struct{}) { - defer close(chkCh) +// fetchInnerRows fetches all rows from inner executor, and append them to +// e.innerResult. +func (e *HashJoinExec) fetchInnerRows(ctx context.Context) error { e.innerResult = chunk.NewList(e.innerExec.retTypes(), e.initCap, e.maxChunkSize) e.innerResult.GetMemTracker().AttachTo(e.memTracker) e.innerResult.GetMemTracker().SetLabel("innerResult") var err error for { - select { - case <-doneCh: - return - case <-e.closeCh: - return - default: - if e.finished.Load().(bool) { - return - } - chk := e.children[e.innerIdx].newFirstChunk() - err = e.innerExec.Next(ctx, chk) - if err != nil { - e.innerFinished <- errors.Trace(err) - return - } - if chk.NumRows() == 0 { - return - } - select { - case chkCh <- chk: - break - case <-e.closeCh: - return - } - e.innerResult.Add(chk) - } - } -} - -// partitionInnerRows re-order e.innerResults into sub-relations. -func (e *HashJoinExec) partitionInnerRows() error { - if err := e.preAlloc4InnerParts(); err != nil { - return err - } - - wg := sync.WaitGroup{} - defer wg.Wait() - wg.Add(int(e.concurrency)) - for i := 0; i < int(e.concurrency); i++ { - workerID := i - go util.WithRecovery(func() { - defer wg.Done() - e.doInnerPartition(workerID) - }, e.handlePartitionPanic) - } - return nil -} - -func (e *HashJoinExec) handlePartitionPanic(r interface{}) { - if r != nil { - e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} - } -} - -// doInnerPartition runs concurrently, partitions and copies the inner relation -// to several pre-allocated data partitions. The input inner Chunk idx for each -// partitioner is workerId + x*numPartitioners. -func (e *HashJoinExec) doInnerPartition(workerID int) { - chkIdx, chkNum := workerID, e.innerResult.NumChunks() - for ; chkIdx < chkNum; chkIdx += int(e.concurrency) { - chk := e.innerResult.GetChunk(chkIdx) - for srcRowIdx, partPtr := range e.innerRowPrts[chkIdx] { - if partPtr == partPtr4NullKey { - continue - } - partIdx, destRowIdx := partPtr.partitionIdx, partPtr.rowIdx - part := e.innerParts[partIdx] - part.Insert(int(destRowIdx), chk.GetRow(srcRowIdx)) + if e.finished.Load().(bool) { + return nil } - } -} - -// preAlloc4InnerParts evaluates partRowPtr and pre-alloc the memory space -// for every inner row to help re-order the inner relation. -// TODO: we need to evaluate the skewness for the partitions size, if the -// skewness exceeds a threshold, we do not use partition phase. -func (e *HashJoinExec) preAlloc4InnerParts() (err error) { - hasNull, keyBuf := false, make([]byte, 0, 64) - for chkIdx, chkNum := 0, e.innerResult.NumChunks(); chkIdx < chkNum; chkIdx++ { - chk := e.innerResult.GetChunk(chkIdx) - partPtrs := make([]partRowPtr, chk.NumRows()) - for rowIdx := 0; rowIdx < chk.NumRows(); rowIdx++ { - row := chk.GetRow(rowIdx) - hasNull, keyBuf, err = e.getJoinKeyFromChkRow(false, row, keyBuf) - if err != nil { - return err - } - if hasNull { - partPtrs[rowIdx] = partPtr4NullKey - continue - } - joinHash := murmur3.Sum32(keyBuf) - partIdx := e.radixBits & joinHash - partPtrs[rowIdx].partitionIdx = partIdx - partPtrs[rowIdx].rowIdx = e.getPartition(partIdx).PreAlloc(row) + chk := e.children[e.innerIdx].newFirstChunk() + err = e.innerExec.Next(ctx, chunk.NewRecordBatch(chk)) + if err != nil || chk.NumRows() == 0 { + return err } - e.innerRowPrts = append(e.innerRowPrts, partPtrs) - } - return -} - -func (e *HashJoinExec) getPartition(idx uint32) partition { - if e.innerParts[idx] == nil { - e.innerParts[idx] = chunk.New(e.innerExec.retTypes(), e.initCap, e.maxChunkSize) + e.innerResult.Add(chk) } - return e.innerParts[idx] -} - -// evalRadixBit evaluates the radix bit numbers. -// To ensure that one partition of inner relation, one hash table, one partition -// of outer relation and the join result of these two partitions fit into the L2 -// cache when the input data obeys the uniform distribution, we suppose every -// sub-partition of inner relation using three quarters of the L2 cache size. -func (e *HashJoinExec) evalRadixBit() (needPartition bool) { - sv := e.ctx.GetSessionVars() - innerResultSize := float64(e.innerResult.GetMemTracker().BytesConsumed()) - l2CacheSize := float64(sv.L2CacheSize) * 3 / 4 - radixBitsNum := uint(math.Ceil(math.Log2(innerResultSize / l2CacheSize))) - if radixBitsNum <= 0 { - return false - } - // Take the rightmost radixBitsNum bits as the bitmask. - e.radixBits = ^(math.MaxUint32 << radixBitsNum) - e.innerParts = make([]partition, 1<= e.outerChunk.NumRows() { - err := e.outerExec.Next(ctx, e.outerChunk) + err := e.outerExec.Next(ctx, chunk.NewRecordBatch(e.outerChunk)) if err != nil { return nil, errors.Trace(err) } @@ -839,7 +666,7 @@ func (e *NestedLoopApplyExec) fetchAllInners(ctx context.Context) error { e.innerList.Reset() innerIter := chunk.NewIterator4Chunk(e.innerChunk) for { - err := e.innerExec.Next(ctx, e.innerChunk) + err := e.innerExec.Next(ctx, chunk.NewRecordBatch(e.innerChunk)) if err != nil { return errors.Trace(err) } @@ -860,18 +687,18 @@ func (e *NestedLoopApplyExec) fetchAllInners(ctx context.Context) error { } // Next implements the Executor interface. -func (e *NestedLoopApplyExec) Next(ctx context.Context, chk *chunk.Chunk) (err error) { +func (e *NestedLoopApplyExec) Next(ctx context.Context, req *chunk.RecordBatch) (err error) { if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() + req.Reset() for { if e.innerIter == nil || e.innerIter.Current() == e.innerIter.End() { if e.outerRow != nil && !e.hasMatch { - e.joiner.onMissMatch(*e.outerRow, chk) + e.joiner.onMissMatch(*e.outerRow, req.Chunk) } - e.outerRow, err = e.fetchSelectedOuterRow(ctx, chk) + e.outerRow, err = e.fetchSelectedOuterRow(ctx, req.Chunk) if e.outerRow == nil || err != nil { return errors.Trace(err) } @@ -888,10 +715,10 @@ func (e *NestedLoopApplyExec) Next(ctx context.Context, chk *chunk.Chunk) (err e e.innerIter.Begin() } - matched, err := e.joiner.tryToMatch(*e.outerRow, e.innerIter, chk) + matched, err := e.joiner.tryToMatch(*e.outerRow, e.innerIter, req.Chunk) e.hasMatch = e.hasMatch || matched - if err != nil || chk.NumRows() == e.maxChunkSize { + if err != nil || req.NumRows() == e.maxChunkSize { return errors.Trace(err) } } diff --git a/executor/join_test.go b/executor/join_test.go index 51100540d5f78..b70764ac52282 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -759,8 +759,8 @@ func (s *testSuite2) TestJoinLeak(c *C) { tk.MustExec("commit") result, err := tk.Exec("select * from t t1 left join (select 1) t2 on 1") c.Assert(err, IsNil) - chk := result.NewChunk() - err = result.Next(context.Background(), chk) + req := result.NewRecordBatch() + err = result.Next(context.Background(), req) c.Assert(err, IsNil) time.Sleep(time.Millisecond) result.Close() @@ -989,3 +989,22 @@ func (s *testSuite2) TestHashJoin(c *C) { innerExecInfo = row[3][4].(string) c.Assert(innerExecInfo[len(innerExecInfo)-1:], LessEqual, "5") } + +func (s *testSuite2) TestJoinDifferentDecimals(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("Use test") + tk.MustExec("Drop table if exists t1") + tk.MustExec("Create table t1 (v int)") + tk.MustExec("Insert into t1 value (1)") + tk.MustExec("Insert into t1 value (2)") + tk.MustExec("Insert into t1 value (3)") + tk.MustExec("Drop table if exists t2") + tk.MustExec("Create table t2 (v decimal(12, 3))") + tk.MustExec("Insert into t2 value (1)") + tk.MustExec("Insert into t2 value (2.0)") + tk.MustExec("Insert into t2 value (000003.000000)") + rst := tk.MustQuery("Select * from t1, t2 where t1.v = t2.v order by t1.v") + row := rst.Rows() + c.Assert(len(row), Equals, 3) + rst.Check(testkit.Rows("1 1.000", "2 2.000", "3 3.000")) +} diff --git a/executor/load_data.go b/executor/load_data.go index df81f8ccd2730..8246212d80804 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -49,8 +49,8 @@ func NewLoadDataInfo(ctx sessionctx.Context, row []types.Datum, tbl table.Table, } // Next implements the Executor Next interface. -func (e *LoadDataExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.GrowAndReset(e.maxChunkSize) +func (e *LoadDataExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.GrowAndReset(e.maxChunkSize) // TODO: support load data without local field. if !e.IsLocal { return errors.New("Load Data: don't support load data without local field") @@ -202,7 +202,7 @@ func (e *LoadDataInfo) getLine(prevData, curData []byte) ([]byte, []byte, bool) } // InsertData inserts data into specified table according to the specified format. -// If it has the rest of data isn't completed the processing, then is returns without completed data. +// If it has the rest of data isn't completed the processing, then it returns without completed data. // If the number of inserted rows reaches the batchRows, then the second return value is true. // If prevData isn't nil and curData is nil, there are no other data to deal with and the isEOF is true. func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error) { diff --git a/executor/load_stats.go b/executor/load_stats.go index b7799c17dba64..34bdb47117b99 100644 --- a/executor/load_stats.go +++ b/executor/load_stats.go @@ -50,8 +50,8 @@ func (k loadStatsVarKeyType) String() string { const LoadStatsVarKey loadStatsVarKeyType = 0 // Next implements the Executor Next interface. -func (e *LoadStatsExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.GrowAndReset(e.maxChunkSize) +func (e *LoadStatsExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.GrowAndReset(e.maxChunkSize) if len(e.info.Path) == 0 { return errors.New("Load Stats: file path is empty") } diff --git a/executor/merge_join.go b/executor/merge_join.go index 14456561f826b..30eb1d15fba78 100644 --- a/executor/merge_join.go +++ b/executor/merge_join.go @@ -138,7 +138,7 @@ func (t *mergeJoinInnerTable) nextRow() (chunk.Row, error) { if t.curRow == t.curIter.End() { t.reallocReaderResult() oldMemUsage := t.curResult.MemoryUsage() - err := t.reader.Next(t.ctx, t.curResult) + err := t.reader.Next(t.ctx, chunk.NewRecordBatch(t.curResult)) // error happens or no more data. if err != nil || t.curResult.NumRows() == 0 { t.curRow = t.curIter.End() @@ -262,20 +262,20 @@ func (e *MergeJoinExec) prepare(ctx context.Context, chk *chunk.Chunk) error { } // Next implements the Executor Next interface. -func (e *MergeJoinExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *MergeJoinExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() + req.Reset() if !e.prepared { - if err := e.prepare(ctx, chk); err != nil { + if err := e.prepare(ctx, req.Chunk); err != nil { return errors.Trace(err) } } - for chk.NumRows() < e.maxChunkSize { - hasMore, err := e.joinToChunk(ctx, chk) + for req.NumRows() < e.maxChunkSize { + hasMore, err := e.joinToChunk(ctx, req.Chunk) if err != nil || !hasMore { return errors.Trace(err) } @@ -355,7 +355,7 @@ func (e *MergeJoinExec) fetchNextInnerRows() (err error) { // may not all belong to the same join key, but are guaranteed to be sorted // according to the join key. func (e *MergeJoinExec) fetchNextOuterRows(ctx context.Context) (err error) { - err = e.outerTable.reader.Next(ctx, e.outerTable.chk) + err = e.outerTable.reader.Next(ctx, chunk.NewRecordBatch(e.outerTable.chk)) if err != nil { return errors.Trace(err) } diff --git a/executor/pkg_test.go b/executor/pkg_test.go index 6bda2e044f6c7..be624dc138ad2 100644 --- a/executor/pkg_test.go +++ b/executor/pkg_test.go @@ -32,14 +32,14 @@ type MockExec struct { curRowIdx int } -func (m *MockExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (m *MockExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() colTypes := m.retTypes() - for ; m.curRowIdx < len(m.Rows) && chk.NumRows() < chk.Capacity(); m.curRowIdx++ { + for ; m.curRowIdx < len(m.Rows) && req.NumRows() < req.Capacity(); m.curRowIdx++ { curRow := m.Rows[m.curRowIdx] for i := 0; i < curRow.Len(); i++ { curDatum := curRow.ToRow().GetDatum(i, colTypes[i]) - chk.AppendDatum(i, &curDatum) + req.AppendDatum(i, &curDatum) } } return nil @@ -103,7 +103,7 @@ func (s *pkgTestSuite) TestNestedLoopApply(c *C) { joinChk := join.newFirstChunk() it := chunk.NewIterator4Chunk(joinChk) for rowIdx := 1; ; { - err := join.Next(ctx, joinChk) + err := join.Next(ctx, chunk.NewRecordBatch(joinChk)) c.Check(err, IsNil) if joinChk.NumRows() == 0 { break @@ -129,7 +129,7 @@ func prepareOneColChildExec(sctx sessionctx.Context, rowCount int) Executor { return exec } -func prepare4RadixPartition(sctx sessionctx.Context, rowCount int) *HashJoinExec { +func buildExec4RadixHashJoin(sctx sessionctx.Context, rowCount int) *RadixHashJoinExec { childExec0 := prepareOneColChildExec(sctx, rowCount) childExec1 := prepareOneColChildExec(sctx, rowCount) @@ -148,12 +148,12 @@ func prepare4RadixPartition(sctx sessionctx.Context, rowCount int) *HashJoinExec innerExec: childExec0, outerExec: childExec1, } - return hashJoinExec + return &RadixHashJoinExec{HashJoinExec: hashJoinExec} } func (s *pkgTestSuite) TestRadixPartition(c *C) { sctx := mock.NewContext() - hashJoinExec := prepare4RadixPartition(sctx, 200) + hashJoinExec := buildExec4RadixHashJoin(sctx, 200) sv := sctx.GetSessionVars() originL2CacheSize, originEnableRadixJoin, originMaxChunkSize := sv.L2CacheSize, sv.EnableRadixJoin, sv.MaxChunkSize sv.L2CacheSize = 100 @@ -169,14 +169,7 @@ func (s *pkgTestSuite) TestRadixPartition(c *C) { err := hashJoinExec.Open(ctx) c.Assert(err, IsNil) - innerResultCh := make(chan *chunk.Chunk, 1) - doneCh := make(chan struct{}) - go func() { - for range innerResultCh { - } - }() - - hashJoinExec.fetchInnerRows(ctx, innerResultCh, doneCh) + hashJoinExec.fetchInnerRows(ctx) c.Assert(hashJoinExec.innerResult.GetMemTracker().BytesConsumed(), Equals, int64(14400)) hashJoinExec.evalRadixBit() @@ -247,7 +240,7 @@ func (s *pkgTestSuite) TestMoveInfoSchemaToFront(c *C) { func BenchmarkPartitionInnerRows(b *testing.B) { sctx := mock.NewContext() - hashJoinExec := prepare4RadixPartition(sctx, 1500000) + hashJoinExec := buildExec4RadixHashJoin(sctx, 1500000) sv := sctx.GetSessionVars() originL2CacheSize, originEnableRadixJoin, originMaxChunkSize := sv.L2CacheSize, sv.EnableRadixJoin, sv.MaxChunkSize sv.L2CacheSize = cpuid.CPU.Cache.L2 @@ -260,14 +253,7 @@ func BenchmarkPartitionInnerRows(b *testing.B) { ctx := context.Background() hashJoinExec.Open(ctx) - innerResultCh := make(chan *chunk.Chunk, 1) - doneCh := make(chan struct{}) - go func() { - for range innerResultCh { - } - }() - - hashJoinExec.fetchInnerRows(ctx, innerResultCh, doneCh) + hashJoinExec.fetchInnerRows(ctx) hashJoinExec.evalRadixBit() b.ResetTimer() hashJoinExec.concurrency = 16 @@ -278,3 +264,29 @@ func BenchmarkPartitionInnerRows(b *testing.B) { hashJoinExec.innerRowPrts = hashJoinExec.innerRowPrts[:0] } } + +func (s *pkgTestSuite) TestParallelBuildHashTable4RadixJoin(c *C) { + sctx := mock.NewContext() + hashJoinExec := buildExec4RadixHashJoin(sctx, 200) + + sv := sctx.GetSessionVars() + sv.L2CacheSize = 100 + sv.EnableRadixJoin = true + sv.MaxChunkSize = 100 + sv.StmtCtx.MemTracker = memory.NewTracker("RootMemTracker", variable.DefTiDBMemQuotaHashJoin) + + ctx := context.Background() + err := hashJoinExec.Open(ctx) + c.Assert(err, IsNil) + + hashJoinExec.partitionInnerAndBuildHashTables(ctx) + innerParts := hashJoinExec.innerParts + c.Assert(len(hashJoinExec.hashTables), Equals, len(innerParts)) + for i := 0; i < len(innerParts); i++ { + if innerParts[i] == nil { + c.Assert(hashJoinExec.hashTables[i], IsNil) + } else { + c.Assert(hashJoinExec.hashTables[i], NotNil) + } + } +} diff --git a/executor/point_get.go b/executor/point_get.go index 920538cc12ed0..1dfa5a32783e3 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -73,8 +73,8 @@ func (e *PointGetExecutor) Close() error { } // Next implements the Executor interface. -func (e *PointGetExecutor) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() if e.done { return nil } @@ -113,7 +113,7 @@ func (e *PointGetExecutor) Next(ctx context.Context, chk *chunk.Chunk) error { } return nil } - return e.decodeRowValToChunk(val, chk) + return e.decodeRowValToChunk(val, req.Chunk) } func (e *PointGetExecutor) encodeIndexKey() ([]byte, error) { diff --git a/executor/prepared.go b/executor/prepared.go index 28445b1c770da..f35502cf01aad 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -94,7 +94,7 @@ func NewPrepareExec(ctx sessionctx.Context, is infoschema.InfoSchema, sqlTxt str } // Next implements the Executor Next interface. -func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *PrepareExec) Next(ctx context.Context, req *chunk.RecordBatch) error { vars := e.ctx.GetSessionVars() if e.ID != 0 { // Must be the case when we retry a prepare. @@ -201,7 +201,7 @@ type ExecuteExec struct { } // Next implements the Executor Next interface. -func (e *ExecuteExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *ExecuteExec) Next(ctx context.Context, req *chunk.RecordBatch) error { return nil } @@ -237,7 +237,7 @@ type DeallocateExec struct { } // Next implements the Executor Next interface. -func (e *DeallocateExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *DeallocateExec) Next(ctx context.Context, req *chunk.RecordBatch) error { vars := e.ctx.GetSessionVars() id, ok := vars.PreparedStmtNameToID[e.Name] if !ok { diff --git a/executor/projection.go b/executor/projection.go index bedc7eb952b29..a54aa74758377 100644 --- a/executor/projection.go +++ b/executor/projection.go @@ -143,7 +143,7 @@ func (e *ProjectionExec) Open(ctx context.Context) error { // | | | | // +------------------------------+ +----------------------+ // -func (e *ProjectionExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *ProjectionExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("projection.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -151,13 +151,13 @@ func (e *ProjectionExec) Next(ctx context.Context, chk *chunk.Chunk) error { if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.GrowAndReset(e.maxChunkSize) + req.GrowAndReset(e.maxChunkSize) if e.isUnparallelExec() { - return errors.Trace(e.unParallelExecute(ctx, chk)) + return errors.Trace(e.unParallelExecute(ctx, req.Chunk)) } - return errors.Trace(e.parallelExecute(ctx, chk)) + return errors.Trace(e.parallelExecute(ctx, req.Chunk)) } @@ -166,7 +166,7 @@ func (e *ProjectionExec) isUnparallelExec() bool { } func (e *ProjectionExec) unParallelExecute(ctx context.Context, chk *chunk.Chunk) error { - err := e.children[0].Next(ctx, e.childResult) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) if err != nil { return errors.Trace(err) } @@ -294,7 +294,7 @@ func (f *projectionInputFetcher) run(ctx context.Context) { f.globalOutputCh <- output - err := f.child.Next(ctx, input.chk) + err := f.child.Next(ctx, chunk.NewRecordBatch(input.chk)) if err != nil || input.chk.NumRows() == 0 { output.done <- errors.Trace(err) return diff --git a/executor/radix_hash_join.go b/executor/radix_hash_join.go new file mode 100644 index 0000000000000..a5dc9e52a4d34 --- /dev/null +++ b/executor/radix_hash_join.go @@ -0,0 +1,270 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "math" + "sync" + "time" + "unsafe" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/mvmap" + log "github.com/sirupsen/logrus" + "github.com/spaolacci/murmur3" +) + +var ( + _ Executor = &RadixHashJoinExec{} +) + +// RadixHashJoinExec implements the radix partition-based hash join algorithm. +// It will partition the input relations into small pairs of partitions where +// one of the partitions typically fits into one of the caches. The overall goal +// of this method is to minimize the number of cache misses when building and +// probing hash tables. +type RadixHashJoinExec struct { + *HashJoinExec + + // radixBits indicates the bits using for radix partitioning. Inner relation + // will be split to 2^radixBitsNumber sub-relations before building the hash + // tables. If the complete inner relation can be hold in L2Cache in which + // case radixBits will be 1, we can skip the partition phase. + // Note: We actually check whether `size of sub inner relation < 3/4 * L2 + // cache size` to make sure one inner sub-relation, hash table, one outer + // sub-relation and join result of the sub-relations can be totally loaded + // in L2 cache size. `3/4` is a magic number, we may adjust it after + // benchmark. + radixBits uint32 + innerParts []partition + numNonEmptyPart int + // innerRowPrts indicates the position in corresponding partition of every + // row in innerResult. + innerRowPrts [][]partRowPtr + // hashTables stores the hash tables built from the inner relation, if there + // is no partition phase, a global hash table will be stored in + // hashTables[0]. + hashTables []*mvmap.MVMap +} + +// partition stores the sub-relations of inner relation and outer relation after +// partition phase. Every partition can be fully stored in L2 cache thus can +// reduce the cache miss ratio when building and probing the hash table. +type partition = *chunk.Chunk + +// partRowPtr stores the actual index in `innerParts` or `outerParts`. +type partRowPtr struct { + partitionIdx uint32 + rowIdx uint32 +} + +// partPtr4NullKey indicates a partition pointer which points to a row with null-join-key. +var partPtr4NullKey = partRowPtr{math.MaxUint32, math.MaxUint32} + +// Next implements the Executor Next interface. +// radix hash join constructs the result following these steps: +// step 1. fetch data from inner child +// step 2. parallel partition the inner relation into sub-relations and build an +// individual hash table for every partition +// step 3. fetch data from outer child in a background goroutine and partition +// it into sub-relations +// step 4. probe the corresponded sub-hash-table for every sub-outer-relation in +// multiple join workers +func (e *RadixHashJoinExec) Next(ctx context.Context, req *chunk.RecordBatch) (err error) { + if e.runtimeStats != nil { + start := time.Now() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() + } + if !e.prepared { + e.innerFinished = make(chan error, 1) + go util.WithRecovery(func() { e.partitionInnerAndBuildHashTables(ctx) }, e.handleFetchInnerAndBuildHashTablePanic) + // TODO: parallel fetch outer rows, partition them and do parallel join + e.prepared = true + } + return <-e.innerFinished +} + +// partitionInnerRows re-order e.innerResults into sub-relations. +func (e *RadixHashJoinExec) partitionInnerRows() error { + e.evalRadixBit() + if err := e.preAlloc4InnerParts(); err != nil { + return err + } + + wg := sync.WaitGroup{} + defer wg.Wait() + wg.Add(int(e.concurrency)) + for i := 0; i < int(e.concurrency); i++ { + workerID := i + go util.WithRecovery(func() { + defer wg.Done() + e.doInnerPartition(workerID) + }, e.handlePartitionPanic) + } + return nil +} + +func (e *RadixHashJoinExec) handlePartitionPanic(r interface{}) { + if r != nil { + e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} + } +} + +// doInnerPartition runs concurrently, partitions and copies the inner relation +// to several pre-allocated data partitions. The input inner Chunk idx for each +// partitioner is workerId + x*numPartitioners. +func (e *RadixHashJoinExec) doInnerPartition(workerID int) { + chkIdx, chkNum := workerID, e.innerResult.NumChunks() + for ; chkIdx < chkNum; chkIdx += int(e.concurrency) { + if e.finished.Load().(bool) { + return + } + chk := e.innerResult.GetChunk(chkIdx) + for srcRowIdx, partPtr := range e.innerRowPrts[chkIdx] { + if partPtr == partPtr4NullKey { + continue + } + partIdx, destRowIdx := partPtr.partitionIdx, partPtr.rowIdx + part := e.innerParts[partIdx] + part.Insert(int(destRowIdx), chk.GetRow(srcRowIdx)) + } + } +} + +// preAlloc4InnerParts evaluates partRowPtr and pre-alloc the memory space +// for every inner row to help re-order the inner relation. +// TODO: we need to evaluate the skewness for the partitions size, if the +// skewness exceeds a threshold, we do not use partition phase. +func (e *RadixHashJoinExec) preAlloc4InnerParts() (err error) { + hasNull, keyBuf := false, make([]byte, 0, 64) + for chkIdx, chkNum := 0, e.innerResult.NumChunks(); chkIdx < chkNum; chkIdx++ { + chk := e.innerResult.GetChunk(chkIdx) + partPtrs := make([]partRowPtr, chk.NumRows()) + for rowIdx := 0; rowIdx < chk.NumRows(); rowIdx++ { + row := chk.GetRow(rowIdx) + hasNull, keyBuf, err = e.getJoinKeyFromChkRow(false, row, keyBuf) + if err != nil { + return err + } + if hasNull { + partPtrs[rowIdx] = partPtr4NullKey + continue + } + joinHash := murmur3.Sum32(keyBuf) + partIdx := e.radixBits & joinHash + partPtrs[rowIdx].partitionIdx = partIdx + partPtrs[rowIdx].rowIdx = e.getPartition(partIdx).PreAlloc(row) + } + e.innerRowPrts = append(e.innerRowPrts, partPtrs) + } + if e.numNonEmptyPart < len(e.innerParts) { + numTotalPart := len(e.innerParts) + numEmptyPart := numTotalPart - e.numNonEmptyPart + log.Debugf("[EMPTY_PART_IN_RADIX_HASH_JOIN] txn_start_ts:%v, num_empty_parts:%v, "+ + "num_total_parts:%v, empty_ratio:%v", e.ctx.GetSessionVars().TxnCtx.StartTS, + numEmptyPart, numTotalPart, float64(numEmptyPart)/float64(numTotalPart)) + } + return +} + +func (e *RadixHashJoinExec) getPartition(idx uint32) partition { + if e.innerParts[idx] == nil { + e.numNonEmptyPart++ + e.innerParts[idx] = chunk.New(e.innerExec.retTypes(), e.initCap, e.maxChunkSize) + } + return e.innerParts[idx] +} + +// evalRadixBit evaluates the radix bit numbers. +// To ensure that one partition of inner relation, one hash table, one partition +// of outer relation and the join result of these two partitions fit into the L2 +// cache when the input data obeys the uniform distribution, we suppose every +// sub-partition of inner relation using three quarters of the L2 cache size. +func (e *RadixHashJoinExec) evalRadixBit() { + sv := e.ctx.GetSessionVars() + innerResultSize := float64(e.innerResult.GetMemTracker().BytesConsumed()) + l2CacheSize := float64(sv.L2CacheSize) * 3 / 4 + radixBitsNum := math.Ceil(math.Log2(innerResultSize / l2CacheSize)) + if radixBitsNum <= 0 { + radixBitsNum = 1 + } + // Take the rightmost radixBitsNum bits as the bitmask. + e.radixBits = ^(math.MaxUint32 << uint(radixBitsNum)) + e.innerParts = make([]partition, 1< 0 && e.children[0] != nil { return e.insertRowsFromSelect(ctx, e.exec) } diff --git a/executor/revoke.go b/executor/revoke.go index 2e6954d11151c..bd4ed1129ea91 100644 --- a/executor/revoke.go +++ b/executor/revoke.go @@ -51,7 +51,7 @@ type RevokeExec struct { } // Next implements the Executor Next interface. -func (e *RevokeExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *RevokeExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.done { return nil } diff --git a/executor/seqtest/prepared_test.go b/executor/seqtest/prepared_test.go index e83bba66e5c01..924e549e96963 100644 --- a/executor/seqtest/prepared_test.go +++ b/executor/seqtest/prepared_test.go @@ -142,8 +142,8 @@ func (s *seqTestSuite) TestPrepared(c *C) { c.Assert(err, IsNil) rs, err = stmt.Exec(ctx) c.Assert(err, IsNil) - chk := rs.NewChunk() - err = rs.Next(ctx, chk) + req := rs.NewRecordBatch() + err = rs.Next(ctx, req) c.Assert(err, IsNil) c.Assert(rs.Close(), IsNil) diff --git a/executor/seqtest/seq_executor_test.go b/executor/seqtest/seq_executor_test.go index 693947c7dabb9..48fcac309954c 100644 --- a/executor/seqtest/seq_executor_test.go +++ b/executor/seqtest/seq_executor_test.go @@ -134,8 +134,8 @@ func (s *seqTestSuite) TestEarlyClose(c *C) { rss, err1 := tk.Se.Execute(ctx, "select * from earlyclose order by id") c.Assert(err1, IsNil) rs := rss[0] - chk := rs.NewChunk() - err = rs.Next(ctx, chk) + req := rs.NewRecordBatch() + err = rs.Next(ctx, req) c.Assert(err, IsNil) rs.Close() } @@ -146,8 +146,8 @@ func (s *seqTestSuite) TestEarlyClose(c *C) { rss, err := tk.Se.Execute(ctx, "select * from earlyclose") c.Assert(err, IsNil) rs := rss[0] - chk := rs.NewChunk() - err = rs.Next(ctx, chk) + req := rs.NewRecordBatch() + err = rs.Next(ctx, req) c.Assert(err, NotNil) rs.Close() } @@ -642,8 +642,8 @@ func (s *seqTestSuite) TestIndexDoubleReadClose(c *C) { rs, err := tk.Exec("select * from dist where c_idx between 0 and 100") c.Assert(err, IsNil) - chk := rs.NewChunk() - err = rs.Next(context.Background(), chk) + req := rs.NewRecordBatch() + err = rs.Next(context.Background(), req) c.Assert(err, IsNil) c.Assert(err, IsNil) keyword := "pickAndExecTask" @@ -672,8 +672,8 @@ func (s *seqTestSuite) TestParallelHashAggClose(c *C) { rss, err := tk.Se.Execute(ctx, "select sum(a) from (select cast(t.a as signed) as a, b from t) t group by b;") c.Assert(err, IsNil) rs := rss[0] - chk := rs.NewChunk() - err = rs.Next(ctx, chk) + req := rs.NewRecordBatch() + err = rs.Next(ctx, req) c.Assert(err.Error(), Equals, "HashAggExec.parallelExec error") } @@ -691,8 +691,8 @@ func (s *seqTestSuite) TestUnparallelHashAggClose(c *C) { rss, err := tk.Se.Execute(ctx, "select sum(distinct a) from (select cast(t.a as signed) as a, b from t) t group by b;") c.Assert(err, IsNil) rs := rss[0] - chk := rs.NewChunk() - err = rs.Next(ctx, chk) + req := rs.NewRecordBatch() + err = rs.Next(ctx, req) c.Assert(err.Error(), Equals, "HashAggExec.unparallelExec error") } diff --git a/executor/set.go b/executor/set.go index 9b67a749b735a..e4ad853181532 100644 --- a/executor/set.go +++ b/executor/set.go @@ -42,8 +42,8 @@ type SetExecutor struct { } // Next implements the Executor Next interface. -func (e *SetExecutor) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *SetExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() if e.done { return nil } diff --git a/executor/set_test.go b/executor/set_test.go index 51a35985d8125..8b78a8f72823b 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -576,3 +576,23 @@ func (s *testSuite2) TestValidateSetVar(c *C) { _, err = tk.Exec("set @@tx_isolation='SERIALIZABLE'") c.Assert(terror.ErrorEqual(err, variable.ErrUnsupportedValueForVar), IsTrue, Commentf("err %v", err)) } + +func (s *testSuite2) TestSelectGlobalVar(c *C) { + tk := testkit.NewTestKit(c, s.store) + + tk.MustQuery("select @@global.max_connections;").Check(testkit.Rows("151")) + tk.MustQuery("select @@max_connections;").Check(testkit.Rows("151")) + + tk.MustExec("set @@global.max_connections=100;") + + tk.MustQuery("select @@global.max_connections;").Check(testkit.Rows("100")) + tk.MustQuery("select @@max_connections;").Check(testkit.Rows("100")) + + tk.MustExec("set @@global.max_connections=151;") + + // test for unknown variable. + _, err := tk.Exec("select @@invalid") + c.Assert(terror.ErrorEqual(err, variable.UnknownSystemVar), IsTrue, Commentf("err %v", err)) + _, err = tk.Exec("select @@global.invalid") + c.Assert(terror.ErrorEqual(err, variable.UnknownSystemVar), IsTrue, Commentf("err %v", err)) +} diff --git a/executor/show.go b/executor/show.go index 8b0cd089edb14..e0c0f829149d1 100644 --- a/executor/show.go +++ b/executor/show.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/infoschema" plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/plugin" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -66,8 +67,8 @@ type ShowExec struct { } // Next implements the Executor Next interface. -func (e *ShowExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.GrowAndReset(e.maxChunkSize) +func (e *ShowExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.GrowAndReset(e.maxChunkSize) if e.result == nil { e.result = e.newFirstChunk() err := e.fetchAll() @@ -90,8 +91,8 @@ func (e *ShowExec) Next(ctx context.Context, chk *chunk.Chunk) error { if e.cursor >= e.result.NumRows() { return nil } - numCurBatch := mathutil.Min(chk.Capacity(), e.result.NumRows()-e.cursor) - chk.Append(e.result, e.cursor, e.cursor+numCurBatch) + numCurBatch := mathutil.Min(req.Capacity(), e.result.NumRows()-e.cursor) + req.Append(e.result, e.cursor, e.cursor+numCurBatch) e.cursor += numCurBatch return nil } @@ -875,6 +876,12 @@ func (e *ShowExec) fetchShowProcedureStatus() error { } func (e *ShowExec) fetchShowPlugins() error { + tiPlugins := plugin.GetAll() + for _, ps := range tiPlugins { + for _, p := range ps { + e.appendRow([]interface{}{p.Name, p.State.String(), p.Kind.String(), p.Path, p.License, strconv.Itoa(int(p.Version))}) + } + } return nil } diff --git a/executor/simple.go b/executor/simple.go index 387c4c0b385b6..b78db5aa72a67 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -49,7 +49,7 @@ type SimpleExec struct { } // Next implements the Executor Next interface. -func (e *SimpleExec) Next(ctx context.Context, chk *chunk.Chunk) (err error) { +func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err error) { if e.done { return nil } diff --git a/executor/sort.go b/executor/sort.go index 2a0020c702f72..24520c9a11832 100644 --- a/executor/sort.go +++ b/executor/sort.go @@ -74,16 +74,16 @@ func (e *SortExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *SortExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *SortExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("sort.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() + req.Reset() if !e.fetched { err := e.fetchRowChunks(ctx) if err != nil { @@ -104,12 +104,12 @@ func (e *SortExec) Next(ctx context.Context, chk *chunk.Chunk) error { } e.fetched = true } - for chk.NumRows() < e.maxChunkSize { + for req.NumRows() < e.maxChunkSize { if e.Idx >= len(e.rowPtrs) { return nil } rowPtr := e.rowPtrs[e.Idx] - chk.AppendRow(e.rowChunks.GetRow(rowPtr)) + req.AppendRow(e.rowChunks.GetRow(rowPtr)) e.Idx++ } return nil @@ -122,7 +122,7 @@ func (e *SortExec) fetchRowChunks(ctx context.Context) error { e.rowChunks.GetMemTracker().SetLabel("rowChunks") for { chk := e.children[0].newFirstChunk() - err := e.children[0].Next(ctx, chk) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk)) if err != nil { return errors.Trace(err) } @@ -305,16 +305,16 @@ func (e *TopNExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *TopNExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *TopNExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("topN.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.Reset() + req.Reset() if !e.fetched { e.totalLimit = e.limit.Offset + e.limit.Count e.Idx = int(e.limit.Offset) @@ -331,9 +331,9 @@ func (e *TopNExec) Next(ctx context.Context, chk *chunk.Chunk) error { if e.Idx >= len(e.rowPtrs) { return nil } - for chk.NumRows() < e.maxChunkSize && e.Idx < len(e.rowPtrs) { + for req.NumRows() < e.maxChunkSize && e.Idx < len(e.rowPtrs) { row := e.rowChunks.GetRow(e.rowPtrs[e.Idx]) - chk.AppendRow(row) + req.AppendRow(row) e.Idx++ } return nil @@ -346,7 +346,7 @@ func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error { e.rowChunks.GetMemTracker().SetLabel("rowChunks") for uint64(e.rowChunks.Len()) < e.totalLimit { srcChk := e.children[0].newFirstChunk() - err := e.children[0].Next(ctx, srcChk) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(srcChk)) if err != nil { return errors.Trace(err) } @@ -382,7 +382,7 @@ func (e *TopNExec) executeTopN(ctx context.Context) error { } childRowChk := e.children[0].newFirstChunk() for { - err := e.children[0].Next(ctx, childRowChk) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(childRowChk)) if err != nil { return errors.Trace(err) } diff --git a/executor/table_reader.go b/executor/table_reader.go index af2202919e511..2f1bf0b3569db 100644 --- a/executor/table_reader.go +++ b/executor/table_reader.go @@ -100,16 +100,16 @@ func (e *TableReaderExecutor) Open(ctx context.Context) error { // Next fills data into the chunk passed by its caller. // The task was actually done by tableReaderHandler. -func (e *TableReaderExecutor) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *TableReaderExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("tableReader.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } if e.runtimeStats != nil { start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - if err := e.resultHandler.nextChunk(ctx, chk); err != nil { + if err := e.resultHandler.nextChunk(ctx, req.Chunk); err != nil { e.feedback.Invalidate() return err } diff --git a/executor/trace.go b/executor/trace.go index 0ae3cfa003a19..8b75fa2b4564c 100644 --- a/executor/trace.go +++ b/executor/trace.go @@ -16,18 +16,16 @@ package executor import ( "context" "encoding/json" + "sort" "time" "github.com/opentracing/basictracer-go" "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/parser/ast" - "github.com/pingcap/tidb/planner" - plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" - "github.com/pingcap/tidb/util/tracing" "sourcegraph.com/sourcegraph/appdash" traceImpl "sourcegraph.com/sourcegraph/appdash/opentracing" ) @@ -50,137 +48,117 @@ type TraceExec struct { } // Next executes real query and collects span later. -func (e *TraceExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *TraceExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() if e.exhausted { return nil } - - if e.format == "json" { - if se, ok := e.ctx.(sqlexec.SQLExecutor); ok { - store := appdash.NewMemoryStore() - tracer := traceImpl.NewTracer(store) - span := tracer.StartSpan("trace") - defer span.Finish() - ctx = opentracing.ContextWithSpan(ctx, span) - recordSets, err := se.Execute(ctx, e.stmtNode.Text()) - if err != nil { - return errors.Trace(err) - } - - for _, rs := range recordSets { - _, err = drainRecordSet(ctx, e.ctx, rs) - if err != nil { - return errors.Trace(err) - } - if err = rs.Close(); err != nil { - return errors.Trace(err) - } - } - - traces, err := store.Traces(appdash.TracesOpts{}) - if err != nil { - return errors.Trace(err) - } - data, err := json.Marshal(traces) - if err != nil { - return errors.Trace(err) - } - - // Split json data into rows to avoid the max packet size limitation. - const maxRowLen = 4096 - for len(data) > maxRowLen { - chk.AppendString(0, string(data[:maxRowLen])) - data = data[maxRowLen:] - } - chk.AppendString(0, string(data)) - } + se, ok := e.ctx.(sqlexec.SQLExecutor) + if !ok { e.exhausted = true return nil } - // TODO: If the following code is never used, remove it later. - // record how much time was spent for optimizeing plan - optimizeSp := e.rootTrace.Tracer().StartSpan("plan_optimize", opentracing.FollowsFrom(e.rootTrace.Context())) - stmtPlan, err := planner.Optimize(e.builder.ctx, e.stmtNode, e.builder.is) + store := appdash.NewMemoryStore() + tracer := traceImpl.NewTracer(store) + span := tracer.StartSpan("trace") + defer span.Finish() + ctx = opentracing.ContextWithSpan(ctx, span) + recordSets, err := se.Execute(ctx, e.stmtNode.Text()) if err != nil { - return err + return errors.Trace(err) } - optimizeSp.Finish() - pp, ok := stmtPlan.(plannercore.PhysicalPlan) - if !ok { - return errors.New("cannot cast logical plan to physical plan") + for _, rs := range recordSets { + _, err = drainRecordSet(ctx, e.ctx, rs) + if err != nil { + return errors.Trace(err) + } + if err = rs.Close(); err != nil { + return errors.Trace(err) + } } - // append select executor to trace executor - stmtExec := e.builder.build(pp) - - e.rootTrace = tracing.NewRecordedTrace("trace_exec", func(sp basictracer.RawSpan) { - e.CollectedSpans = append(e.CollectedSpans, sp) - }) - err = stmtExec.Open(ctx) + traces, err := store.Traces(appdash.TracesOpts{}) if err != nil { return errors.Trace(err) } - stmtExecChk := stmtExec.newFirstChunk() - // store span into context - ctx = opentracing.ContextWithSpan(ctx, e.rootTrace) - - for { - if err := stmtExec.Next(ctx, stmtExecChk); err != nil { - return errors.Trace(err) - } - if stmtExecChk.NumRows() == 0 { - break + // Row format. + if e.format != "json" { + if len(traces) < 1 { + e.exhausted = true + return nil } + trace := traces[0] + sortTraceByStartTime(trace) + dfsTree(trace, "", false, req.Chunk) + e.exhausted = true + return nil } - e.rootTrace.LogKV("event", "tracing completed") - e.rootTrace.Finish() - var rootSpan basictracer.RawSpan - - treeSpans := make(map[uint64][]basictracer.RawSpan) - for _, sp := range e.CollectedSpans { - treeSpans[sp.ParentSpanID] = append(treeSpans[sp.ParentSpanID], sp) - // if a span's parentSpanID is 0, then it is root span - // this is by design - if sp.ParentSpanID == 0 { - rootSpan = sp - } + // Json format. + data, err := json.Marshal(traces) + if err != nil { + return errors.Trace(err) } - dfsTree(rootSpan, treeSpans, "", false, chk) + // Split json data into rows to avoid the max packet size limitation. + const maxRowLen = 4096 + for len(data) > maxRowLen { + req.AppendString(0, string(data[:maxRowLen])) + data = data[maxRowLen:] + } + req.AppendString(0, string(data)) e.exhausted = true return nil } func drainRecordSet(ctx context.Context, sctx sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) { var rows []chunk.Row - chk := rs.NewChunk() + req := rs.NewRecordBatch() for { - err := rs.Next(ctx, chk) - if err != nil || chk.NumRows() == 0 { + err := rs.Next(ctx, req) + if err != nil || req.NumRows() == 0 { return rows, errors.Trace(err) } - iter := chunk.NewIterator4Chunk(chk) + iter := chunk.NewIterator4Chunk(req.Chunk) for r := iter.Begin(); r != iter.End(); r = iter.Next() { rows = append(rows, r) } - chk = chunk.Renew(chk, sctx.GetSessionVars().MaxChunkSize) + req.Chunk = chunk.Renew(req.Chunk, sctx.GetSessionVars().MaxChunkSize) } } -func dfsTree(span basictracer.RawSpan, tree map[uint64][]basictracer.RawSpan, prefix string, isLast bool, chk *chunk.Chunk) { - suffix := "" - spans := tree[span.Context.SpanID] - var newPrefix string - if span.ParentSpanID == 0 { - newPrefix = prefix +type sortByStartTime []*appdash.Trace + +func (t sortByStartTime) Len() int { return len(t) } +func (t sortByStartTime) Less(i, j int) bool { + return getStartTime(t[j]).After(getStartTime(t[i])) +} +func (t sortByStartTime) Swap(i, j int) { t[i], t[j] = t[j], t[i] } + +func getStartTime(trace *appdash.Trace) (t time.Time) { + if e, err := trace.TimespanEvent(); err == nil { + t = e.Start() + } + return +} + +func sortTraceByStartTime(trace *appdash.Trace) { + sort.Sort(sortByStartTime(trace.Sub)) + for _, t := range trace.Sub { + sortTraceByStartTime(t) + } +} + +func dfsTree(t *appdash.Trace, prefix string, isLast bool, chk *chunk.Chunk) { + var newPrefix, suffix string + if len(prefix) == 0 { + newPrefix = prefix + " " } else { - if len(tree[span.ParentSpanID]) > 0 && !isLast { + if !isLast { suffix = "├─" newPrefix = prefix + "│ " } else { @@ -189,11 +167,19 @@ func dfsTree(span basictracer.RawSpan, tree map[uint64][]basictracer.RawSpan, pr } } - chk.AppendString(0, prefix+suffix+span.Operation) - chk.AppendString(1, span.Start.Format(time.StampNano)) - chk.AppendString(2, span.Duration.String()) + var start time.Time + var duration time.Duration + if e, err := t.TimespanEvent(); err == nil { + start = e.Start() + end := e.End() + duration = end.Sub(start) + } + + chk.AppendString(0, prefix+suffix+t.Span.Name()) + chk.AppendString(1, start.Format("15:04:05.000000")) + chk.AppendString(2, duration.String()) - for i, sp := range spans { - dfsTree(sp, tree, newPrefix, i == (len(spans))-1 /*last element of array*/, chk) + for i, sp := range t.Sub { + dfsTree(sp, newPrefix, i == (len(t.Sub))-1 /*last element of array*/, chk) } } diff --git a/executor/trace_test.go b/executor/trace_test.go index ec298500a918d..fe60b58692cd5 100644 --- a/executor/trace_test.go +++ b/executor/trace_test.go @@ -26,4 +26,19 @@ func (s *testSuite1) TestTraceExec(c *C) { tk.MustExec("trace insert into trace (c1, c2, c3) values (1, 2, 3)") rows := tk.MustQuery("trace select * from trace where id = 0;").Rows() c.Assert(rows, HasLen, 1) + + // +---------------------------+-----------------+------------+ + // | operation | startTS | duration | + // +---------------------------+-----------------+------------+ + // | session.getTxnFuture | 22:08:38.247834 | 78.909µs | + // | ├─session.Execute | 22:08:38.247829 | 1.478487ms | + // | ├─session.ParseSQL | 22:08:38.248457 | 71.159µs | + // | ├─executor.Compile | 22:08:38.248578 | 45.329µs | + // | ├─session.runStmt | 22:08:38.248661 | 75.13µs | + // | ├─session.CommitTxn | 22:08:38.248699 | 13.213µs | + // | └─recordSet.Next | 22:08:38.249340 | 155.317µs | + // +---------------------------+-----------------+------------+ + rows = tk.MustQuery("trace format='row' select * from trace where id = 0;").Rows() + + c.Assert(len(rows) > 1, IsTrue) } diff --git a/executor/union_scan.go b/executor/union_scan.go index 187d8e0c7a112..dc8bd0f10e49d 100644 --- a/executor/union_scan.go +++ b/executor/union_scan.go @@ -125,14 +125,14 @@ func (us *UnionScanExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (us *UnionScanExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (us *UnionScanExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if us.runtimeStats != nil { start := time.Now() - defer func() { us.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + defer func() { us.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() } - chk.GrowAndReset(us.maxChunkSize) + req.GrowAndReset(us.maxChunkSize) mutableRow := chunk.MutRowFromTypes(us.retTypes()) - for i, batchSize := 0, chk.Capacity(); i < batchSize; i++ { + for i, batchSize := 0, req.Capacity(); i < batchSize; i++ { row, err := us.getOneRow(ctx) if err != nil { return errors.Trace(err) @@ -142,7 +142,7 @@ func (us *UnionScanExec) Next(ctx context.Context, chk *chunk.Chunk) error { return nil } mutableRow.SetDatums(row...) - chk.AppendRow(mutableRow.ToRow()) + req.AppendRow(mutableRow.ToRow()) } return nil } @@ -197,7 +197,7 @@ func (us *UnionScanExec) getSnapshotRow(ctx context.Context) ([]types.Datum, err us.cursor4SnapshotRows = 0 us.snapshotRows = us.snapshotRows[:0] for len(us.snapshotRows) == 0 { - err = us.children[0].Next(ctx, us.snapshotChunkBuffer) + err = us.children[0].Next(ctx, chunk.NewRecordBatch(us.snapshotChunkBuffer)) if err != nil || us.snapshotChunkBuffer.NumRows() == 0 { return nil, errors.Trace(err) } diff --git a/executor/update.go b/executor/update.go index 9227bfa690ed6..3e9a5db6cfc81 100644 --- a/executor/update.go +++ b/executor/update.go @@ -132,13 +132,13 @@ func (e *UpdateExec) canNotUpdate(handle types.Datum) bool { } // Next implements the Executor Next interface. -func (e *UpdateExec) Next(ctx context.Context, chk *chunk.Chunk) error { +func (e *UpdateExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("update.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } - chk.Reset() + req.Reset() if !e.fetched { err := e.fetchChunkRows(ctx) if err != nil { @@ -181,7 +181,7 @@ func (e *UpdateExec) fetchChunkRows(ctx context.Context) error { chk := e.children[0].newFirstChunk() e.evalBuffer = chunk.MutRowFromTypes(fields) for { - err := e.children[0].Next(ctx, chk) + err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk)) if err != nil { return errors.Trace(err) } diff --git a/executor/write_test.go b/executor/write_test.go index f00ebcf2bdee4..ccbb18726c673 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -237,6 +237,27 @@ func (s *testSuite2) TestInsert(c *C) { tk.MustExec("insert into test values(2, 3)") tk.MustQuery("select * from test use index (id) where id = 2").Check(testkit.Rows("2 2", "2 3")) + // issue 6360 + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a bigint unsigned);") + tk.MustExec(" set @orig_sql_mode = @@sql_mode; set @@sql_mode = 'strict_all_tables';") + _, err = tk.Exec("insert into t value (-1);") + c.Assert(types.ErrWarnDataOutOfRange.Equal(err), IsTrue) + tk.MustExec("set @@sql_mode = '';") + tk.MustExec("insert into t value (-1);") + // TODO: the following warning messages are not consistent with MySQL, fix them in the future PRs + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t select -1;") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t select cast(-1 as unsigned);") + tk.MustExec("insert into t value (-1.111);") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustExec("insert into t value ('-1.111');") + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 BIGINT UNSIGNED value is out of range in '-1'")) + r = tk.MustQuery("select * from t;") + r.Check(testkit.Rows("0", "0", "18446744073709551615", "0", "0")) + tk.MustExec("set @@sql_mode = @orig_sql_mode;") + // issue 6424 tk.MustExec("drop table if exists t") tk.MustExec("create table t(a time(6))") @@ -271,6 +292,13 @@ func (s *testSuite2) TestInsert(c *C) { tk.MustExec("truncate table t") tk.MustExec("insert into t (b) values(default(a))") tk.MustQuery("select * from t").Check(testkit.Rows("1 1")) + + tk.MustExec("create view v as select * from t") + _, err = tk.Exec("insert into v values(1,2)") + c.Assert(err.Error(), Equals, "insert into view v is not supported now.") + _, err = tk.Exec("replace into v values(1,2)") + c.Assert(err.Error(), Equals, "replace into view v is not supported now.") + tk.MustExec("drop view v") } func (s *testSuite2) TestInsertAutoInc(c *C) { @@ -1302,6 +1330,11 @@ func (s *testSuite) TestUpdate(c *C) { tk.MustExec("update t set b = ''") tk.MustQuery("select * from t").Check(testkit.Rows("0000-00-00 00:00:00 ")) tk.MustExec("set @@sql_mode=@orig_sql_mode;") + + tk.MustExec("create view v as select * from t") + _, err = tk.Exec("update v set a = '2000-11-11'") + c.Assert(err.Error(), Equals, "update view v is not supported now.") + tk.MustExec("drop view v") } func (s *testSuite2) TestPartitionedTableUpdate(c *C) { @@ -1559,6 +1592,11 @@ func (s *testSuite) TestDelete(c *C) { tk.MustExec(`delete from delete_test ;`) tk.CheckExecResult(1, 0) + + tk.MustExec("create view v as select * from delete_test") + _, err = tk.Exec("delete from v where name = 'aaa'") + c.Assert(err.Error(), Equals, "delete view v is not supported now.") + tk.MustExec("drop view v") } func (s *testSuite2) TestPartitionedTableDelete(c *C) { @@ -1920,6 +1958,26 @@ func (s *testSuite2) TestLoadDataIgnoreLines(c *C) { checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL) } +// related to issue 6360 +func (s *testSuite2) TestLoadDataOverflowBigintUnsigned(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test; drop table if exists load_data_test;") + tk.MustExec("CREATE TABLE load_data_test (a bigint unsigned);") + tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test") + ctx := tk.Se.(sessionctx.Context) + ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataInfo) + c.Assert(ok, IsTrue) + defer ctx.SetValue(executor.LoadDataVarKey, nil) + c.Assert(ld, NotNil) + tests := []testCase{ + {nil, []byte("-1\n-18446744073709551615\n-18446744073709551616\n"), []string{"0", "0", "0"}, nil, "Records: 3 Deleted: 0 Skipped: 0 Warnings: 3"}, + {nil, []byte("-9223372036854775809\n18446744073709551616\n"), []string{"0", "18446744073709551615"}, nil, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 2"}, + } + deleteSQL := "delete from load_data_test" + selectSQL := "select * from load_data_test;" + checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL) +} + func (s *testSuite2) TestBatchInsertDelete(c *C) { originLimit := atomic.LoadUint64(&kv.TxnEntryCountLimit) defer func() { diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index e1550ef7c3012..80163ef3699b1 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -464,7 +464,8 @@ func (b *builtinCastIntAsRealSig) evalReal(row chunk.Row) (res float64, isNull b res = 0 } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) res = float64(uVal) } return res, false, err @@ -491,7 +492,8 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyDe res = &types.MyDecimal{} } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) if err != nil { return res, false, err } @@ -520,7 +522,8 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul res = strconv.FormatInt(val, 10) } else { var uVal uint64 - uVal, err = types.ConvertIntToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + sc := b.ctx.GetSessionVars().StmtCtx + uVal, err = types.ConvertIntToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) if err != nil { return res, false, err } @@ -750,7 +753,8 @@ func (b *builtinCastRealAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool res = 0 } else { var uintVal uint64 - uintVal, err = types.ConvertFloatToUint(val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble) + sc := b.ctx.GetSessionVars().StmtCtx + uintVal, err = types.ConvertFloatToUint(sc, val, types.UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeDouble) res = int64(uintVal) } return res, isNull, err @@ -1001,7 +1005,7 @@ func (b *builtinCastDecimalAsDurationSig) Clone() builtinFunc { func (b *builtinCastDecimalAsDurationSig) evalDuration(row chunk.Row) (res types.Duration, isNull bool, err error) { val, isNull, err := b.args[0].EvalDecimal(b.ctx, row) if isNull || err != nil { - return res, false, err + return res, true, err } res, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, string(val.ToString()), b.tp.Decimal) if types.ErrTruncatedWrongVal.Equal(err) { diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index cf4744efb8d0e..85083cc5cfeb3 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -93,6 +93,9 @@ var aesModes = map[string]*aesModeAttr{ "aes-128-cbc": {"cbc", 16, true}, "aes-192-cbc": {"cbc", 24, true}, "aes-256-cbc": {"cbc", 32, true}, + "aes-128-ofb": {"ofb", 16, true}, + "aes-192-ofb": {"ofb", 24, true}, + "aes-256-ofb": {"ofb", 32, true}, "aes-128-cfb": {"cfb", 16, true}, "aes-192-cfb": {"cfb", 24, true}, "aes-256-cfb": {"cfb", 32, true}, @@ -212,6 +215,8 @@ func (b *builtinAesDecryptIVSig) evalString(row chunk.Row) (string, bool, error) switch b.modeName { case "cbc": plainText, err = encrypt.AESDecryptWithCBC([]byte(cryptStr), key, []byte(iv)) + case "ofb": + plainText, err = encrypt.AESDecryptWithOFB([]byte(cryptStr), key, []byte(iv)) case "cfb": plainText, err = encrypt.AESDecryptWithCFB([]byte(cryptStr), key, []byte(iv)) default: @@ -337,6 +342,8 @@ func (b *builtinAesEncryptIVSig) evalString(row chunk.Row) (string, bool, error) switch b.modeName { case "cbc": cipherText, err = encrypt.AESEncryptWithCBC([]byte(str), key, []byte(iv)) + case "ofb": + cipherText, err = encrypt.AESEncryptWithOFB([]byte(str), key, []byte(iv)) case "cfb": cipherText, err = encrypt.AESEncryptWithCFB([]byte(str), key, []byte(iv)) default: diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index e66b9053dc2cb..4a12f14f1a9af 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -101,6 +101,13 @@ var aesTests = []struct { {"aes-256-cbc", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "5D0E22C1E77523AEF5C3E10B65653C8F"}, {"aes-256-cbc", "pingcap", []interface{}{"12345678901234561234567890123456", "1234567890123456"}, "A26BA27CA4BE9D361D545AA84A17002D"}, {"aes-256-cbc", "pingcap", []interface{}{"1234567890123456", "12345678901234561234567890123456"}, "5D0E22C1E77523AEF5C3E10B65653C8F"}, + // test for ofb + {"aes-128-ofb", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "0515A36BBF3DE0"}, + {"aes-128-ofb", "pingcap", []interface{}{"123456789012345678901234", "1234567890123456"}, "C2A93A93818546"}, + {"aes-192-ofb", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "FE09DCCF14D458"}, + {"aes-256-ofb", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "2E70FCAC0C0834"}, + {"aes-256-ofb", "pingcap", []interface{}{"12345678901234561234567890123456", "1234567890123456"}, "83E2B30A71F011"}, + {"aes-256-ofb", "pingcap", []interface{}{"1234567890123456", "12345678901234561234567890123456"}, "2E70FCAC0C0834"}, // test for cfb {"aes-128-cfb", "pingcap", []interface{}{"1234567890123456", "1234567890123456"}, "0515A36BBF3DE0"}, {"aes-128-cfb", "pingcap", []interface{}{"123456789012345678901234", "1234567890123456"}, "C2A93A93818546"}, diff --git a/expression/builtin_json.go b/expression/builtin_json.go index 4f5486e9b1af3..7fb2aec1e2164 100644 --- a/expression/builtin_json.go +++ b/expression/builtin_json.go @@ -403,6 +403,11 @@ func (b *builtinJSONMergeSig) evalJSON(row chunk.Row) (res json.BinaryJSON, isNu values = append(values, value) } res = json.MergeBinary(values) + // function "JSON_MERGE" is deprecated since MySQL 5.7.22. Synonym for function "JSON_MERGE_PRESERVE". + // See https://dev.mysql.com/doc/refman/5.7/en/json-modification-functions.html#function_json-merge + if b.pbCode == tipb.ScalarFuncSig_JsonMergeSig { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errDeprecatedSyntaxNoReplacement.GenWithStackByArgs("JSON_MERGE")) + } return res, false, nil } @@ -720,7 +725,17 @@ type jsonMergePreserveFunctionClass struct { } func (c *jsonMergePreserveFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "JSON_MERGE_PRESERVE") + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETJson) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETJson, argTps...) + sig := &builtinJSONMergeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonMergePreserveSig) + return sig, nil } type jsonPrettyFunctionClass struct { diff --git a/expression/builtin_json_test.go b/expression/builtin_json_test.go index 7b78081803a7e..8f6888e8c035c 100644 --- a/expression/builtin_json_test.go +++ b/expression/builtin_json_test.go @@ -187,6 +187,39 @@ func (s *testEvaluatorSuite) TestJSONMerge(c *C) { j2 := d.GetMysqlJSON() cmp := json.CompareBinary(j1, j2) c.Assert(cmp, Equals, 0, Commentf("got %v expect %v", j1.String(), j2.String())) + case nil: + c.Assert(d.IsNull(), IsTrue) + } + } +} + +func (s *testEvaluatorSuite) TestJSONMergePreserve(c *C) { + defer testleak.AfterTest(c)() + fc := funcs[ast.JSONMergePreserve] + tbl := []struct { + Input []interface{} + Expected interface{} + }{ + {[]interface{}{nil, nil}, nil}, + {[]interface{}{`{}`, `[]`}, `[{}]`}, + {[]interface{}{`{}`, `[]`, `3`, `"4"`}, `[{}, 3, "4"]`}, + } + for _, t := range tbl { + args := types.MakeDatums(t.Input...) + f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) + c.Assert(err, IsNil) + d, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + + switch x := t.Expected.(type) { + case string: + j1, err := json.ParseBinaryFromString(x) + c.Assert(err, IsNil) + j2 := d.GetMysqlJSON() + cmp := json.CompareBinary(j1, j2) + c.Assert(cmp, Equals, 0, Commentf("got %v expect %v", j1.String(), j2.String())) + case nil: + c.Assert(d.IsNull(), IsTrue) } } } diff --git a/expression/errors.go b/expression/errors.go index 8a18acfa2ad5a..1ebc4ffdaba7b 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -72,7 +72,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error { return err } sc := ctx.GetSessionVars().StmtCtx - if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateOrDeleteStmt) { + if ctx.GetSessionVars().StrictSQLMode && (sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt) { return err } sc.AppendWarning(err) @@ -82,7 +82,7 @@ func handleInvalidTimeError(ctx sessionctx.Context, err error) error { // handleDivisionByZeroError reports error or warning depend on the context. func handleDivisionByZeroError(ctx sessionctx.Context) error { sc := ctx.GetSessionVars().StmtCtx - if sc.InInsertStmt || sc.InUpdateOrDeleteStmt { + if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt { if !ctx.GetSessionVars().SQLMode.HasErrorForDivisionByZeroMode() { return nil } diff --git a/expression/integration_test.go b/expression/integration_test.go index aed28975c7227..cf6d02b574954 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1006,6 +1006,21 @@ func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { result.Check(testkit.Rows("341672829F84CB6B0BE690FEC4C4DAE9 341672829F84CB6B0BE690FEC4C4DAE9 D43734E147A12BB96C6897C4BBABA283 16F2C972411948DCEF3659B726D2CCB04AD1379A1A367FA64242058A50211B67 41E71D0C58967C1F50EEC074523946D1 1117D292E2D39C3EAA3B435371BE56FC 8ACB7ECC0883B672D7BD1CFAA9FA5FAF5B731ADE978244CD581F114D591C2E7E D2B13C30937E3251AEDA73859BA32E4B 2CF4A6051FF248A67598A17AA2C17267")) result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") result.Check(testkit.Rows(`80D5646F07B4654B05A02D9085759770 80D5646F07B4654B05A02D9085759770 B3C14BA15030D2D7E99376DBE011E752 0CD2936EE4FEC7A8CDF6208438B2BC05 `)) + tk.MustExec("SET block_encryption_mode='aes-128-ofb';") + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', '1234567890123456')), HEX(AES_ENCRYPT(b, 'key', '1234567890123456')), HEX(AES_ENCRYPT(c, 'key', '1234567890123456')), HEX(AES_ENCRYPT(d, 'key', '1234567890123456')), HEX(AES_ENCRYPT(e, 'key', '1234567890123456')), HEX(AES_ENCRYPT(f, 'key', '1234567890123456')), HEX(AES_ENCRYPT(g, 'key', '1234567890123456')), HEX(AES_ENCRYPT(h, 'key', '1234567890123456')), HEX(AES_ENCRYPT(i, 'key', '1234567890123456')) from t") + result.Check(testkit.Rows("40 40 40C35C 40DD5EBDFCAA397102386E27DDF97A39ECCEC5 43DF55BAE0A0386D 78 47DC5D8AD19A085C32094E16EFC34A08D6FEF459 46D5 06840BE8")) + result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`48E38A 48E38A 9D6C199101C3 `)) + tk.MustExec("SET block_encryption_mode='aes-192-ofb';") + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', '1234567890123456')), HEX(AES_ENCRYPT(b, 'key', '1234567890123456')), HEX(AES_ENCRYPT(c, 'key', '1234567890123456')), HEX(AES_ENCRYPT(d, 'key', '1234567890123456')), HEX(AES_ENCRYPT(e, 'key', '1234567890123456')), HEX(AES_ENCRYPT(f, 'key', '1234567890123456')), HEX(AES_ENCRYPT(g, 'key', '1234567890123456')), HEX(AES_ENCRYPT(h, 'key', '1234567890123456')), HEX(AES_ENCRYPT(i, 'key', '1234567890123456')) from t") + result.Check(testkit.Rows("4B 4B 4B573F 4B493D42572E6477233A429BF3E0AD39DB816D 484B36454B24656B 73 4C483E757A1E555A130B62AAC1DA9D08E1B15C47 4D41 0D106817")) + result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`3A76B0 3A76B0 EFF92304268E `)) + tk.MustExec("SET block_encryption_mode='aes-256-ofb';") + result = tk.MustQuery("select HEX(AES_ENCRYPT(a, 'key', '1234567890123456')), HEX(AES_ENCRYPT(b, 'key', '1234567890123456')), HEX(AES_ENCRYPT(c, 'key', '1234567890123456')), HEX(AES_ENCRYPT(d, 'key', '1234567890123456')), HEX(AES_ENCRYPT(e, 'key', '1234567890123456')), HEX(AES_ENCRYPT(f, 'key', '1234567890123456')), HEX(AES_ENCRYPT(g, 'key', '1234567890123456')), HEX(AES_ENCRYPT(h, 'key', '1234567890123456')), HEX(AES_ENCRYPT(i, 'key', '1234567890123456')) from t") + result.Check(testkit.Rows("16 16 16D103 16CF01CBC95D33E2ED721CBD930262415A69AD 15CD0ACCD55732FE 2E 11CE02FCE46D02CFDD433C8CA138527060599C35 10C7 5096549E")) + result = tk.MustQuery("select HEX(AES_ENCRYPT('123', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT(123, 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('', 'foobar', '1234567890123456')), HEX(AES_ENCRYPT('你好', 'foobar', '1234567890123456')), AES_ENCRYPT(NULL, 'foobar', '1234567890123456')") + result.Check(testkit.Rows(`E842C5 E842C5 3DCD5646767D `)) // for AES_DECRYPT tk.MustExec("SET block_encryption_mode='aes-128-ecb';") @@ -1018,6 +1033,21 @@ func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { result.Check(testkit.Rows("foo")) result = tk.MustQuery("select AES_DECRYPT(UNHEX('80D5646F07B4654B05A02D9085759770'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('B3C14BA15030D2D7E99376DBE011E752'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('0CD2936EE4FEC7A8CDF6208438B2BC05'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456')") result.Check(testkit.Rows(`123 你好 `)) + tk.MustExec("SET block_encryption_mode='aes-128-ofb';") + result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar', '1234567890123456'), 'bar', '1234567890123456')") + result.Check(testkit.Rows("foo")) + result = tk.MustQuery("select AES_DECRYPT(UNHEX('48E38A'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX(''), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('9D6C199101C3'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), HEX(AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456'))") + result.Check(testkit.Rows(`123 你好 2A9EF431FB2ACB022D7F2E7C71EEC48C7D2B`)) + tk.MustExec("SET block_encryption_mode='aes-192-ofb';") + result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar', '1234567890123456'), 'bar', '1234567890123456')") + result.Check(testkit.Rows("foo")) + result = tk.MustQuery("select AES_DECRYPT(UNHEX('3A76B0'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX(''), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('EFF92304268E'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), HEX(AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456'))") + result.Check(testkit.Rows(`123 你好 580BCEA4DC67CF33FF2C7C570D36ECC89437`)) + tk.MustExec("SET block_encryption_mode='aes-256-ofb';") + result = tk.MustQuery("select AES_DECRYPT(AES_ENCRYPT('foo', 'bar', '1234567890123456'), 'bar', '1234567890123456')") + result.Check(testkit.Rows("foo")) + result = tk.MustQuery("select AES_DECRYPT(UNHEX('E842C5'), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX(''), 'foobar', '1234567890123456'), AES_DECRYPT(UNHEX('3DCD5646767D'), 'foobar', '1234567890123456'), AES_DECRYPT(NULL, 'foobar', '1234567890123456'), HEX(AES_DECRYPT('SOME_THING_STRANGE', 'foobar', '1234567890123456'))") + result.Check(testkit.Rows(`123 你好 8A3FBBE68C9465834584430E3AEEBB04B1F5`)) // for COMPRESS tk.MustExec("DROP TABLE IF EXISTS t1;") @@ -3797,3 +3827,33 @@ func (s *testIntegrationSuite) TestUserVarMockWindFunc(c *C) { `3 6 3 key3-value6 insert_order6`, )) } + +func (s *testIntegrationSuite) TestCastAsTime(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test;`) + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t (col1 bigint, col2 double, col3 decimal, col4 varchar(20), col5 json);`) + tk.MustExec(`insert into t values (1, 1, 1, "1", "1");`) + tk.MustExec(`insert into t values (null, null, null, null, null);`) + tk.MustQuery(`select cast(col1 as time), cast(col2 as time), cast(col3 as time), cast(col4 as time), cast(col5 as time) from t where col1 = 1;`).Check(testkit.Rows( + `00:00:01 00:00:01 00:00:01 00:00:01 00:00:01`, + )) + tk.MustQuery(`select cast(col1 as time), cast(col2 as time), cast(col3 as time), cast(col4 as time), cast(col5 as time) from t where col1 is null;`).Check(testkit.Rows( + ` `, + )) + + err := tk.ExecToErr(`select cast(col1 as time(31)) from t where col1 is null;`) + c.Assert(err.Error(), Equals, "[expression:1426]Too big precision 31 specified for column 'CAST'. Maximum is 6.") + + err = tk.ExecToErr(`select cast(col2 as time(31)) from t where col1 is null;`) + c.Assert(err.Error(), Equals, "[expression:1426]Too big precision 31 specified for column 'CAST'. Maximum is 6.") + + err = tk.ExecToErr(`select cast(col3 as time(31)) from t where col1 is null;`) + c.Assert(err.Error(), Equals, "[expression:1426]Too big precision 31 specified for column 'CAST'. Maximum is 6.") + + err = tk.ExecToErr(`select cast(col4 as time(31)) from t where col1 is null;`) + c.Assert(err.Error(), Equals, "[expression:1426]Too big precision 31 specified for column 'CAST'. Maximum is 6.") + + err = tk.ExecToErr(`select cast(col5 as time(31)) from t where col1 is null;`) + c.Assert(err.Error(), Equals, "[expression:1426]Too big precision 31 specified for column 'CAST'. Maximum is 6.") +} diff --git a/go.mod b/go.mod index fcf325a9541d3..aca95a2dd3866 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/ghodss/yaml v1.0.0 // indirect github.com/go-ole/go-ole v1.2.1 // indirect github.com/go-sql-driver/mysql v0.0.0-20170715192408-3955978caca4 - github.com/gogo/protobuf v1.1.1 // indirect + github.com/gogo/protobuf v1.2.0 // indirect github.com/golang/groupcache v0.0.0-20181024230925-c65c006176ff // indirect github.com/golang/protobuf v1.2.0 github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect @@ -47,10 +47,10 @@ require ( github.com/pingcap/errors v0.11.0 github.com/pingcap/gofail v0.0.0-20181217135706-6a951c1e42c3 github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e - github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c - github.com/pingcap/parser v0.0.0-20190106063416-3483d83d44bd + github.com/pingcap/kvproto v0.0.0-20190110035000-d4fe6b336379 + github.com/pingcap/parser v0.0.0-20190114015132-c5c6ec2eb454 github.com/pingcap/pd v2.1.0-rc.4+incompatible - github.com/pingcap/tidb-tools v2.1.1-0.20181218072513-b2235d442b06+incompatible + github.com/pingcap/tidb-tools v2.1.3-0.20190104033906-883b07a04a73+incompatible github.com/pingcap/tipb v0.0.0-20181012112600-11e33c750323 github.com/pkg/errors v0.8.0 // indirect github.com/prometheus/client_golang v0.9.0 @@ -77,10 +77,13 @@ require ( go.uber.org/atomic v1.3.2 // indirect go.uber.org/multierr v1.1.0 // indirect go.uber.org/zap v1.9.1 // indirect - golang.org/x/net v0.0.0-20181029044818-c44066c5c816 + golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e + golang.org/x/sys v0.0.0-20190109145017-48ac38b7c8cb // indirect golang.org/x/text v0.3.0 golang.org/x/time v0.0.0-20181108054448-85acf8d2951c // indirect - google.golang.org/grpc v1.16.0 + golang.org/x/tools v0.0.0-20190110015856-aa033095749b // indirect + google.golang.org/genproto v0.0.0-20190108161440-ae2f86662275 // indirect + google.golang.org/grpc v1.17.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/stretchr/testify.v1 v1.2.2 // indirect sourcegraph.com/sourcegraph/appdash v0.0.0-20180531100431-4c381bd170b4 diff --git a/go.sum b/go.sum index 708d32b715daa..e82541c9b339c 100644 --- a/go.sum +++ b/go.sum @@ -45,8 +45,8 @@ github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dT github.com/go-sql-driver/mysql v0.0.0-20170715192408-3955978caca4 h1:3DFRjZdCDhzvxDf0U6/1qAryzOqD7Y5iAj0DJRRl1bs= github.com/go-sql-driver/mysql v0.0.0-20170715192408-3955978caca4/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/gogo/protobuf v0.0.0-20180717141946-636bf0302bc9/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.0 h1:xU6/SpYbvkNYiptHJYEDRseDLvYE7wSqhYYNy0QSUzI= +github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20181024230925-c65c006176ff h1:kOkM9whyQYodu09SJ6W3NCsHG7crFaJILQ22Gozp3lg= @@ -108,15 +108,14 @@ github.com/pingcap/gofail v0.0.0-20181217135706-6a951c1e42c3 h1:04yuCf5NMvLU8rB2 github.com/pingcap/gofail v0.0.0-20181217135706-6a951c1e42c3/go.mod h1:DazNTg0PTldtpsQiT9I5tVJwV1onHMKBBgXzmJUlMns= github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e h1:P73/4dPCL96rGrobssy1nVy2VaVpNCuLpCbr+FEaTA8= github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e/go.mod h1:O17XtbryoCJhkKGbT62+L2OlrniwqiGLSqrmdHCMzZw= -github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c h1:Qf5St5XGwKgKQLar9lEXoeO0hJMVaFBj3JqvFguWtVg= -github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c/go.mod h1:Ja9XPjot9q4/3JyCZodnWDGNXt4pKemhIYCvVJM7P24= -github.com/pingcap/parser v0.0.0-20190106063416-3483d83d44bd h1:FWAAGBZWj5oL4XwYF3p2rfhMSH5JeApOyW2c7ggCsUY= -github.com/pingcap/parser v0.0.0-20190106063416-3483d83d44bd/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= +github.com/pingcap/kvproto v0.0.0-20190110035000-d4fe6b336379 h1:l4KInBOtxjbgQLjCFHzX66vZgNzsH4a+RiuVZGrO0xk= +github.com/pingcap/kvproto v0.0.0-20190110035000-d4fe6b336379/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY= +github.com/pingcap/parser v0.0.0-20190114015132-c5c6ec2eb454 h1:8wqFaAY5HLvDH35UkzMhtuxb4Q0fk6/yeiOscfOmMpo= +github.com/pingcap/parser v0.0.0-20190114015132-c5c6ec2eb454/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= github.com/pingcap/pd v2.1.0-rc.4+incompatible h1:/buwGk04aHO5odk/+O8ZOXGs4qkUjYTJ2UpCJXna8NE= github.com/pingcap/pd v2.1.0-rc.4+incompatible/go.mod h1:nD3+EoYes4+aNNODO99ES59V83MZSI+dFbhyr667a0E= -github.com/pingcap/tidb-tools v2.1.1-0.20181218072513-b2235d442b06+incompatible h1:Bsd+NHosPVowEGB3BCx+2d8wUQGDTXSSC5ljeNS6cXo= -github.com/pingcap/tidb-tools v2.1.1-0.20181218072513-b2235d442b06+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM= -github.com/pingcap/tipb v0.0.0-20170310053819-1043caee48da/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= +github.com/pingcap/tidb-tools v2.1.3-0.20190104033906-883b07a04a73+incompatible h1:Ba48wwPwPq5hd1kkQpgua49dqB5cthC2zXVo7fUUDec= +github.com/pingcap/tidb-tools v2.1.3-0.20190104033906-883b07a04a73+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM= github.com/pingcap/tipb v0.0.0-20181012112600-11e33c750323 h1:mRKKzRjDNaUNPnAkPAHnRqpNmwNWBX1iA+hxlmvQ93I= github.com/pingcap/tipb v0.0.0-20181012112600-11e33c750323/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= @@ -177,35 +176,44 @@ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793 h1:u+LnwYTOOW7Ukr/fppxEb1Nwz0AtPflrblfvUudpo+I= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181005035420-146acd28ed58/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181029044818-c44066c5c816 h1:mVFkLpejdFLXVUv9E42f3XJVfMdqd0IVLVIVLjZWn5o= -golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e h1:o3PsSEY8E4eXWkXrIP9YJALUkVZqzHJT5DOasTyn8Vs= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.0.0-20171214130843-f21a4dfb5e38/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/sys v0.0.0-20190109145017-48ac38b7c8cb h1:1w588/yEchbPNpa9sEvOcMZYbWHedwJjg4VOAdDHWHk= +golang.org/x/sys v0.0.0-20190109145017-48ac38b7c8cb/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c h1:fqgJT0MGcGpPgpWU7VRdRjuArfcOvC4AoJmILihzhDg= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52 h1:JG/0uqcGdTNgq7FdU+61l5Pdmb8putNZlXb65bJBROs= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181008205924-a2b3f7f249e9 h1:T3nuFyXXDj5KXX9CqQm/r/YEL4Gua01s/ZEdfdLyJ2c= -golang.org/x/tools v0.0.0-20181008205924-a2b3f7f249e9/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190110015856-aa033095749b h1:G5tsw1T5VA7PD7VmXyGtX/hQp3ABPSCPRKVfsdUcVxs= +golang.org/x/tools v0.0.0-20190110015856-aa033095749b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20181004005441-af9cb2a35e7f h1:FU37niK8AQ59mHcskRyQL7H0ErSeNh650vdcj8HqdSI= google.golang.org/genproto v0.0.0-20181004005441-af9cb2a35e7f/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190108161440-ae2f86662275 h1:9oFlwfEGIvmxXTcY53ygNyxIQtWciRHjrnUvZJCYXYU= +google.golang.org/genproto v0.0.0-20190108161440-ae2f86662275/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= google.golang.org/grpc v0.0.0-20180607172857-7a6a684ca69e/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.16.0 h1:dz5IJGuC2BB7qXR5AyHNwAUBhZscK2xVez7mznh72sY= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/grpc v1.17.0 h1:TRJYBgMclJvGYn2rIMjj+h9KtMt5r1Ij7ODVRIZkwhk= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= diff --git a/infoschema/infoschema.go b/infoschema/infoschema.go index d8bdf96b10682..44ab45ff7abfa 100644 --- a/infoschema/infoschema.go +++ b/infoschema/infoschema.go @@ -358,5 +358,13 @@ func initInfoSchemaDB() { // IsMemoryDB checks if the db is in memory. func IsMemoryDB(dbName string) bool { - return dbName == "information_schema" || dbName == "performance_schema" + if dbName == "information_schema" { + return true + } + for _, driver := range drivers { + if driver.DBInfo.Name.L == dbName { + return true + } + } + return false } diff --git a/infoschema/infoschema_test.go b/infoschema/infoschema_test.go index 3381f19d64aa0..41c2e128d0f88 100644 --- a/infoschema/infoschema_test.go +++ b/infoschema/infoschema_test.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" - "github.com/pingcap/tidb/perfschema" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/testleak" @@ -120,7 +119,7 @@ func (*testSuite) TestT(c *C) { schemaNames := is.AllSchemaNames() c.Assert(schemaNames, HasLen, 3) - c.Assert(testutil.CompareUnorderedStringSlice(schemaNames, []string{infoschema.Name, perfschema.Name, "Test"}), IsTrue) + c.Assert(testutil.CompareUnorderedStringSlice(schemaNames, []string{infoschema.Name, "PERFORMANCE_SCHEMA", "Test"}), IsTrue) schemas := is.AllSchemas() c.Assert(schemas, HasLen, 3) diff --git a/perfschema/const.go b/infoschema/perfschema/const.go similarity index 100% rename from perfschema/const.go rename to infoschema/perfschema/const.go diff --git a/perfschema/init.go b/infoschema/perfschema/init.go similarity index 100% rename from perfschema/init.go rename to infoschema/perfschema/init.go diff --git a/infoschema/perfschema/tables.go b/infoschema/perfschema/tables.go new file mode 100644 index 0000000000000..dbe3e68155659 --- /dev/null +++ b/infoschema/perfschema/tables.go @@ -0,0 +1,79 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package perfschema + +import ( + "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/meta/autoid" + "github.com/pingcap/tidb/table" +) + +// perfSchemaTable stands for the fake table all its data is in the memory. +type perfSchemaTable struct { + infoschema.VirtualTable + meta *model.TableInfo + cols []*table.Column +} + +var pluginTable = make(map[string]func(autoid.Allocator, *model.TableInfo) (table.Table, error)) + +// RegisterTable registers a new table into TiDB. +func RegisterTable(tableName, sql string, + tableFromMeta func(autoid.Allocator, *model.TableInfo) (table.Table, error)) { + perfSchemaTables = append(perfSchemaTables, sql) + pluginTable[tableName] = tableFromMeta +} + +func tableFromMeta(alloc autoid.Allocator, meta *model.TableInfo) (table.Table, error) { + if f, ok := pluginTable[meta.Name.L]; ok { + ret, err := f(alloc, meta) + return ret, err + } + return createPerfSchemaTable(meta), nil +} + +// createPerfSchemaTable creates all perfSchemaTables +func createPerfSchemaTable(meta *model.TableInfo) *perfSchemaTable { + columns := make([]*table.Column, 0, len(meta.Columns)) + for _, colInfo := range meta.Columns { + col := table.ToColumn(colInfo) + columns = append(columns, col) + } + t := &perfSchemaTable{ + meta: meta, + cols: columns, + } + return t +} + +// Cols implements table.Table Type interface. +func (vt *perfSchemaTable) Cols() []*table.Column { + return vt.cols +} + +// WritableCols implements table.Table Type interface. +func (vt *perfSchemaTable) WritableCols() []*table.Column { + return vt.cols +} + +// GetID implements table.Table GetID interface. +func (vt *perfSchemaTable) GetPhysicalID() int64 { + return vt.meta.ID +} + +// Meta implements table.Table Type interface. +func (vt *perfSchemaTable) Meta() *model.TableInfo { + return vt.meta +} diff --git a/perfschema/tables_test.go b/infoschema/perfschema/tables_test.go similarity index 100% rename from perfschema/tables_test.go rename to infoschema/perfschema/tables_test.go diff --git a/infoschema/tables.go b/infoschema/tables.go index 2d03d6be03b3a..dc18a9a1d302d 100644 --- a/infoschema/tables.go +++ b/infoschema/tables.go @@ -1603,3 +1603,120 @@ func (it *infoschemaTable) Seek(ctx sessionctx.Context, h int64) (int64, bool, e func (it *infoschemaTable) Type() table.Type { return table.VirtualTable } + +// VirtualTable is a dummy table.Table implementation. +type VirtualTable struct{} + +// IterRecords implements table.Table Type interface. +func (vt *VirtualTable) IterRecords(ctx sessionctx.Context, startKey kv.Key, cols []*table.Column, + fn table.RecordIterFunc) error { + if len(startKey) != 0 { + return table.ErrUnsupportedOp + } + return nil +} + +// RowWithCols implements table.Table Type interface. +func (vt *VirtualTable) RowWithCols(ctx sessionctx.Context, h int64, cols []*table.Column) ([]types.Datum, error) { + return nil, table.ErrUnsupportedOp +} + +// Row implements table.Table Type interface. +func (vt *VirtualTable) Row(ctx sessionctx.Context, h int64) ([]types.Datum, error) { + return nil, table.ErrUnsupportedOp +} + +// Cols implements table.Table Type interface. +func (vt *VirtualTable) Cols() []*table.Column { + return nil +} + +// WritableCols implements table.Table Type interface. +func (vt *VirtualTable) WritableCols() []*table.Column { + return nil +} + +// Indices implements table.Table Type interface. +func (vt *VirtualTable) Indices() []table.Index { + return nil +} + +// WritableIndices implements table.Table Type interface. +func (vt *VirtualTable) WritableIndices() []table.Index { + return nil +} + +// DeletableIndices implements table.Table Type interface. +func (vt *VirtualTable) DeletableIndices() []table.Index { + return nil +} + +// RecordPrefix implements table.Table Type interface. +func (vt *VirtualTable) RecordPrefix() kv.Key { + return nil +} + +// IndexPrefix implements table.Table Type interface. +func (vt *VirtualTable) IndexPrefix() kv.Key { + return nil +} + +// FirstKey implements table.Table Type interface. +func (vt *VirtualTable) FirstKey() kv.Key { + return nil +} + +// RecordKey implements table.Table Type interface. +func (vt *VirtualTable) RecordKey(h int64) kv.Key { + return nil +} + +// AddRecord implements table.Table Type interface. +func (vt *VirtualTable) AddRecord(ctx sessionctx.Context, r []types.Datum, opts ...*table.AddRecordOpt) (recordID int64, err error) { + return 0, table.ErrUnsupportedOp +} + +// RemoveRecord implements table.Table Type interface. +func (vt *VirtualTable) RemoveRecord(ctx sessionctx.Context, h int64, r []types.Datum) error { + return table.ErrUnsupportedOp +} + +// UpdateRecord implements table.Table Type interface. +func (vt *VirtualTable) UpdateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datum, touched []bool) error { + return table.ErrUnsupportedOp +} + +// AllocAutoID implements table.Table Type interface. +func (vt *VirtualTable) AllocAutoID(ctx sessionctx.Context) (int64, error) { + return 0, table.ErrUnsupportedOp +} + +// Allocator implements table.Table Type interface. +func (vt *VirtualTable) Allocator(ctx sessionctx.Context) autoid.Allocator { + return nil +} + +// RebaseAutoID implements table.Table Type interface. +func (vt *VirtualTable) RebaseAutoID(ctx sessionctx.Context, newBase int64, isSetStep bool) error { + return table.ErrUnsupportedOp +} + +// Meta implements table.Table Type interface. +func (vt *VirtualTable) Meta() *model.TableInfo { + return nil +} + +// GetPhysicalID implements table.Table GetPhysicalID interface. +func (vt *VirtualTable) GetPhysicalID() int64 { + return 0 +} + +// Seek implements table.Table Type interface. +func (vt *VirtualTable) Seek(ctx sessionctx.Context, h int64) (int64, bool, error) { + return 0, false, table.ErrUnsupportedOp +} + +// Type implements table.Table Type interface. +func (vt *VirtualTable) Type() table.Type { + return table.VirtualTable +} diff --git a/metrics/metrics.go b/metrics/metrics.go index 56e7fe7d2132a..53a404aee9b88 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -35,6 +35,9 @@ const ( LabelGCWorker = "gcworker" LabelAnalyze = "analyze" + LabelBatchRecvLoop = "batch-recv-loop" + LabelBatchSendLoop = "batch-send-loop" + opSucc = "ok" opFailed = "err" ) @@ -132,4 +135,6 @@ func RegisterMetrics() { prometheus.MustRegister(TotalCopProcHistogram) prometheus.MustRegister(TotalCopWaitHistogram) prometheus.MustRegister(CPUUsagePercentageGauge) + prometheus.MustRegister(TiKVPendingBatchRequests) + prometheus.MustRegister(TiKVBatchWaitDuration) } diff --git a/metrics/tikvclient.go b/metrics/tikvclient.go index bce08635c6cc4..4a530c0512ae2 100644 --- a/metrics/tikvclient.go +++ b/metrics/tikvclient.go @@ -187,4 +187,23 @@ var ( Help: "Wait time of a get local latch.", Buckets: prometheus.ExponentialBuckets(0.0005, 2, 20), }) + + // TiKVPendingBatchRequests indicates the number of requests pending in the batch channel. + TiKVPendingBatchRequests = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "tidb", + Subsystem: "tikvclient", + Name: "pending_batch_requests", + Help: "Pending batch requests", + }) + + TiKVBatchWaitDuration = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Namespace: "tidb", + Subsystem: "tikvclient", + Name: "batch_wait_duration", + // Min bucket is [0, 1ns). + Buckets: prometheus.ExponentialBuckets(1, 2, 30), + Help: "batch wait duration", + }) ) diff --git a/perfschema/tables.go b/perfschema/tables.go deleted file mode 100644 index 2288311cd9bea..0000000000000 --- a/perfschema/tables.go +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package perfschema - -import ( - "github.com/pingcap/parser/model" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/meta/autoid" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/types" -) - -// perfSchemaTable stands for the fake table all its data is in the memory. -type perfSchemaTable struct { - meta *model.TableInfo - cols []*table.Column -} - -func tableFromMeta(alloc autoid.Allocator, meta *model.TableInfo) (table.Table, error) { - return createPerfSchemaTable(meta), nil -} - -// createPerfSchemaTable creates all perfSchemaTables -func createPerfSchemaTable(meta *model.TableInfo) *perfSchemaTable { - columns := make([]*table.Column, 0, len(meta.Columns)) - for _, colInfo := range meta.Columns { - col := table.ToColumn(colInfo) - columns = append(columns, col) - } - t := &perfSchemaTable{ - meta: meta, - cols: columns, - } - return t -} - -// IterRecords implements table.Table Type interface. -func (vt *perfSchemaTable) IterRecords(ctx sessionctx.Context, startKey kv.Key, cols []*table.Column, - fn table.RecordIterFunc) error { - if len(startKey) != 0 { - return table.ErrUnsupportedOp - } - return nil -} - -// RowWithCols implements table.Table Type interface. -func (vt *perfSchemaTable) RowWithCols(ctx sessionctx.Context, h int64, cols []*table.Column) ([]types.Datum, error) { - return nil, table.ErrUnsupportedOp -} - -// Row implements table.Table Type interface. -func (vt *perfSchemaTable) Row(ctx sessionctx.Context, h int64) ([]types.Datum, error) { - return nil, table.ErrUnsupportedOp -} - -// Cols implements table.Table Type interface. -func (vt *perfSchemaTable) Cols() []*table.Column { - return vt.cols -} - -// WritableCols implements table.Table Type interface. -func (vt *perfSchemaTable) WritableCols() []*table.Column { - return vt.cols -} - -// Indices implements table.Table Type interface. -func (vt *perfSchemaTable) Indices() []table.Index { - return nil -} - -// WritableIndices implements table.Table Type interface. -func (vt *perfSchemaTable) WritableIndices() []table.Index { - return nil -} - -// DeletableIndices implements table.Table Type interface. -func (vt *perfSchemaTable) DeletableIndices() []table.Index { - return nil -} - -// RecordPrefix implements table.Table Type interface. -func (vt *perfSchemaTable) RecordPrefix() kv.Key { - return nil -} - -// IndexPrefix implements table.Table Type interface. -func (vt *perfSchemaTable) IndexPrefix() kv.Key { - return nil -} - -// FirstKey implements table.Table Type interface. -func (vt *perfSchemaTable) FirstKey() kv.Key { - return nil -} - -// RecordKey implements table.Table Type interface. -func (vt *perfSchemaTable) RecordKey(h int64) kv.Key { - return nil -} - -// AddRecord implements table.Table Type interface. -func (vt *perfSchemaTable) AddRecord(ctx sessionctx.Context, r []types.Datum, opts ...*table.AddRecordOpt) (recordID int64, err error) { - return 0, table.ErrUnsupportedOp -} - -// RemoveRecord implements table.Table Type interface. -func (vt *perfSchemaTable) RemoveRecord(ctx sessionctx.Context, h int64, r []types.Datum) error { - return table.ErrUnsupportedOp -} - -// UpdateRecord implements table.Table Type interface. -func (vt *perfSchemaTable) UpdateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datum, touched []bool) error { - return table.ErrUnsupportedOp -} - -// AllocAutoID implements table.Table Type interface. -func (vt *perfSchemaTable) AllocAutoID(ctx sessionctx.Context) (int64, error) { - return 0, table.ErrUnsupportedOp -} - -// Allocator implements table.Table Type interface. -func (vt *perfSchemaTable) Allocator(ctx sessionctx.Context) autoid.Allocator { - return nil -} - -// RebaseAutoID implements table.Table Type interface. -func (vt *perfSchemaTable) RebaseAutoID(ctx sessionctx.Context, newBase int64, isSetStep bool) error { - return table.ErrUnsupportedOp -} - -// Meta implements table.Table Type interface. -func (vt *perfSchemaTable) Meta() *model.TableInfo { - return vt.meta -} - -// GetID implements table.Table GetID interface. -func (vt *perfSchemaTable) GetPhysicalID() int64 { - return vt.meta.ID -} - -// Seek implements table.Table Type interface. -func (vt *perfSchemaTable) Seek(ctx sessionctx.Context, h int64) (int64, bool, error) { - return 0, false, table.ErrUnsupportedOp -} - -// Type implements table.Table Type interface. -func (vt *perfSchemaTable) Type() table.Type { - return table.VirtualTable -} diff --git a/planner/core/cbo_test.go b/planner/core/cbo_test.go index 8454b042f7b7d..e8d8e164c50b5 100644 --- a/planner/core/cbo_test.go +++ b/planner/core/cbo_test.go @@ -472,6 +472,11 @@ func (s *testAnalyzeSuite) TestAnalyze(c *C) { testKit.MustExec("insert into t4 (a,b) values (1,1),(1,2),(1,3),(1,4),(2,5),(2,6),(2,7),(2,8)") testKit.MustExec("analyze table t4") + testKit.MustExec("create view v as select * from t") + _, err = testKit.Exec("analyze table v") + c.Assert(err.Error(), Equals, "analyze v is not supported now.") + testKit.MustExec("drop view v") + tests := []struct { sql string best string diff --git a/planner/core/errors.go b/planner/core/errors.go index 562657db2b97b..4f43ca6670f11 100644 --- a/planner/core/errors.go +++ b/planner/core/errors.go @@ -60,6 +60,7 @@ const ( codeWindowNoInherentFrame = mysql.ErrWindowNoInherentFrame codeWindowNoRedefineOrderBy = mysql.ErrWindowNoRedefineOrderBy codeWindowDuplicateName = mysql.ErrWindowDuplicateName + codeErrTooBigPrecision = mysql.ErrTooBigPrecision ) // error definitions. @@ -106,6 +107,7 @@ var ( ErrWindowNoInherentFrame = terror.ClassOptimizer.New(codeWindowNoInherentFrame, mysql.MySQLErrName[mysql.ErrWindowNoInherentFrame]) ErrWindowNoRedefineOrderBy = terror.ClassOptimizer.New(codeWindowNoRedefineOrderBy, mysql.MySQLErrName[mysql.ErrWindowNoRedefineOrderBy]) ErrWindowDuplicateName = terror.ClassOptimizer.New(codeWindowDuplicateName, mysql.MySQLErrName[mysql.ErrWindowDuplicateName]) + errTooBigPrecision = terror.ClassExpression.New(mysql.ErrTooBigPrecision, mysql.MySQLErrName[mysql.ErrTooBigPrecision]) ) func init() { @@ -142,6 +144,7 @@ func init() { codeWindowNoInherentFrame: mysql.ErrWindowNoInherentFrame, codeWindowNoRedefineOrderBy: mysql.ErrWindowNoRedefineOrderBy, codeWindowDuplicateName: mysql.ErrWindowDuplicateName, + codeErrTooBigPrecision: mysql.ErrTooBigPrecision, } terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mysqlErrCodeMap } diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 4f00fbb23c962..5adc459cbca8f 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -797,6 +797,13 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok if er.err != nil { return retNode, false } + + // check the decimal precision of "CAST(AS TIME)". + er.err = er.checkTimePrecision(v.Tp) + if er.err != nil { + return retNode, false + } + er.ctxStack[len(er.ctxStack)-1] = expression.BuildCastFunction(er.ctx, arg, v.Tp) case *ast.PatternLikeExpr: er.likeToScalarFunc(v) @@ -827,6 +834,13 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok return originInNode, true } +func (er *expressionRewriter) checkTimePrecision(ft *types.FieldType) error { + if ft.EvalType() == types.ETDuration && ft.Decimal > types.MaxFsp { + return errTooBigPrecision.GenWithStackByArgs(ft.Decimal, "CAST", types.MaxFsp) + } + return nil +} + func (er *expressionRewriter) useCache() bool { return er.ctx.GetSessionVars().StmtCtx.UseCache } @@ -865,7 +879,13 @@ func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { return } } - if v.IsGlobal { + sysVar := variable.SysVars[name] + if sysVar == nil { + er.err = variable.UnknownSystemVar.GenWithStackByArgs(name) + return + } + // Variable is @@gobal.variable_name or variable is only global scope variable. + if v.IsGlobal || sysVar.Scope == variable.ScopeGlobal { val, err = variable.GetGlobalSystemVar(sessionVars, name) } else { val, err = variable.GetSessionSystemVar(sessionVars, name) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index b677cdf7c96f5..640bf46caac1e 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -1333,7 +1333,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { func tblInfoFromCol(from ast.ResultSetNode, col *expression.Column) *model.TableInfo { var tableList []*ast.TableName - tableList = extractTableList(from, tableList) + tableList = extractTableList(from, tableList, true) for _, field := range tableList { if field.Name.L == col.TblName.L { return field.TableInfo @@ -2304,12 +2304,15 @@ func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { } var tableList []*ast.TableName - tableList = extractTableList(sel.From.TableRefs, tableList) + tableList = extractTableList(sel.From.TableRefs, tableList, false) for _, t := range tableList { dbName := t.Schema.L if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB } + if t.TableInfo.IsView() { + return nil, errors.Errorf("update view %s is not supported now.", t.Name.O) + } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "", nil) } @@ -2424,7 +2427,12 @@ func (b *PlanBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A } for _, assign := range newList { col := assign.Col - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, col.DBName.L, col.TblName.L, "", nil) + + dbName := col.DBName.L + if dbName == "" { + dbName = b.ctx.GetSessionVars().CurrentDB + } + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, dbName, col.OrigTblName.L, "", nil) } return newList, p, nil } @@ -2538,7 +2546,7 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { del.SetSchema(expression.NewSchema()) var tableList []*ast.TableName - tableList = extractTableList(delete.TableRefs.TableRefs, tableList) + tableList = extractTableList(delete.TableRefs.TableRefs, tableList, true) // Collect visitInfo. if delete.Tables != nil { @@ -2574,11 +2582,17 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { // check sql like: `delete b from (select * from t) as a, t` return nil, ErrUnknownTable.GenWithStackByArgs(tn.Name.O, "MULTI DELETE") } + if tn.TableInfo.IsView() { + return nil, errors.Errorf("delete view %s is not supported now.", tn.Name.O) + } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, tn.Schema.L, tn.TableInfo.Name.L, "", nil) } } else { // Delete from a, b, c, d. for _, v := range tableList { + if v.TableInfo.IsView() { + return nil, errors.Errorf("delete view %s is not supported now.", v.Name.O) + } dbName := v.Schema.L if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB @@ -2762,14 +2776,16 @@ func buildWindowSpecs(specs []ast.WindowSpec) (map[string]ast.WindowSpec, error) } // extractTableList extracts all the TableNames from node. -func extractTableList(node ast.ResultSetNode, input []*ast.TableName) []*ast.TableName { +// If asName is true, extract AsName prior to OrigName. +// Privilege check should use OrigName, while expression may use AsName. +func extractTableList(node ast.ResultSetNode, input []*ast.TableName, asName bool) []*ast.TableName { switch x := node.(type) { case *ast.Join: - input = extractTableList(x.Left, input) - input = extractTableList(x.Right, input) + input = extractTableList(x.Left, input, asName) + input = extractTableList(x.Right, input, asName) case *ast.TableSource: if s, ok := x.Source.(*ast.TableName); ok { - if x.AsName.L != "" { + if x.AsName.L != "" && asName { newTableName := *s newTableName.Name = x.AsName s.Name = x.AsName diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 3abf7e6caad04..9c50e8d38cdcc 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -1385,6 +1385,13 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { {mysql.SelectPriv, "test", "t", "", nil}, }, }, + { + sql: "update t a1 set a1.a = a1.a + 1", + ans: []visitInfo{ + {mysql.UpdatePriv, "test", "t", "", nil}, + {mysql.SelectPriv, "test", "t", "", nil}, + }, + }, { sql: "select a, sum(e) from t group by a", ans: []visitInfo{ diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index ccd0f72e2f156..cea8103c07fb1 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -705,6 +705,9 @@ func getPhysicalIDs(tblInfo *model.TableInfo, partitionNames []model.CIStr) ([]i func (b *PlanBuilder) buildAnalyzeTable(as *ast.AnalyzeTableStmt) (Plan, error) { p := &Analyze{MaxNumBuckets: as.MaxNumBuckets} for _, tbl := range as.TableNames { + if tbl.TableInfo.IsView() { + return nil, errors.Errorf("analyze %s is not supported now.", tbl.Name.O) + } idxInfo, colInfo, pkInfo := getColsInfo(tbl) physicalIDs, err := getPhysicalIDs(tbl.TableInfo, as.PartitionNames) if err != nil { @@ -1097,6 +1100,13 @@ func (b *PlanBuilder) buildInsert(insert *ast.InsertStmt) (Plan, error) { return nil, infoschema.ErrTableNotExists.GenWithStackByArgs() } tableInfo := tn.TableInfo + if tableInfo.IsView() { + err := errors.Errorf("insert into view %s is not supported now.", tableInfo.Name.O) + if insert.IsReplace { + err = errors.Errorf("replace into view %s is not supported now.", tableInfo.Name.O) + } + return nil, err + } // Build Schema with DBName otherwise ColumnRef with DBName cannot match any Column in Schema. schema := expression.TableInfo2SchemaWithDBName(b.ctx, tn.Schema, tableInfo) tableInPlan, ok := b.is.TableByID(tableInfo.ID) @@ -1712,9 +1722,9 @@ func buildShowSchema(s *ast.ShowStmt) (schema *expression.Schema) { mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar} case ast.ShowPlugins: - names = []string{"Name", "Status", "Type", "Library", "License"} + names = []string{"Name", "Status", "Type", "Library", "License", "Version"} ftypes = []byte{ - mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, + mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, } case ast.ShowProcessList: names = []string{"Id", "User", "Host", "db", "Command", "Time", "State", "Info"} diff --git a/plugin/conn_ip_example/conn_ip_example.go b/plugin/conn_ip_example/conn_ip_example.go new file mode 100644 index 0000000000000..30cafec357fed --- /dev/null +++ b/plugin/conn_ip_example/conn_ip_example.go @@ -0,0 +1,49 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + + "github.com/pingcap/tidb/plugin" + "github.com/pingcap/tidb/sessionctx/variable" +) + +// Validate implements TiDB plugin's Validate SPI. +func Validate(ctx context.Context, m *plugin.Manifest) error { + fmt.Println("conn_ip_example validate called") + return nil +} + +// OnInit implements TiDB plugin's OnInit SPI. +func OnInit(ctx context.Context, manifest *plugin.Manifest) error { + fmt.Println("conn_ip_example init called") + fmt.Println("read cfg in init", manifest.SysVars["conn_ip_example_test_variable"].Value) + return nil +} + +// OnShutdown implements TiDB plugin's OnShutdown SPI. +func OnShutdown(ctx context.Context, manifest *plugin.Manifest) error { + fmt.Println("conn_ip_examples hutdown called") + return nil +} + +// NotifyEvent implements TiDB Audit plugin's NotifyEvent SPI. +func NotifyEvent(ctx context.Context) error { + fmt.Println("conn_ip_example notifiy called") + fmt.Println("variable test: ", variable.GetSysVar("conn_ip_example_test_variable").Value) + fmt.Printf("new connection by %s\n", ctx.Value("ip")) + return nil +} diff --git a/plugin/conn_ip_example/conn_ip_example_test.go b/plugin/conn_ip_example/conn_ip_example_test.go new file mode 100644 index 0000000000000..9c35f1e67f73e --- /dev/null +++ b/plugin/conn_ip_example/conn_ip_example_test.go @@ -0,0 +1,61 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package main_test + +import ( + "context" + + "github.com/pingcap/tidb/plugin" + "github.com/pingcap/tidb/sessionctx/variable" +) + +func Example_LoadRunShutdownPlugin() { + ctx := context.Background() + var pluginVarNames []string + cfg := plugin.Config{ + Plugins: []string{"conn_ip_example-1"}, + PluginDir: "/home/robi/Code/go/src/github.com/pingcap/tidb/plugin/conn_ip_example", + GlobalSysVar: &variable.SysVars, + PluginVarNames: &pluginVarNames, + } + + err := plugin.Init(ctx, cfg) + if err != nil { + panic(err) + } + + ps := plugin.GetByKind(plugin.Audit) + for _, auditPlugin := range ps { + if auditPlugin.State != plugin.Ready { + continue + } + plugin.DeclareAuditManifest(auditPlugin.Manifest).NotifyEvent(context.Background(), nil) + } + + err = plugin.Reload(ctx, cfg, plugin.ID("conn_ip_example-2")) + if err != nil { + panic(err) + } + + for _, auditPlugin := range plugin.GetByKind(plugin.Audit) { + if auditPlugin.State != plugin.Ready { + continue + } + plugin.DeclareAuditManifest(auditPlugin.Manifest).NotifyEvent( + context.WithValue(context.Background(), "ip", "1.1.1.2"), nil, + ) + } + + plugin.Shutdown(context.Background()) +} diff --git a/plugin/conn_ip_example/manifest.toml b/plugin/conn_ip_example/manifest.toml new file mode 100644 index 0000000000000..8f1a2c74ba7f8 --- /dev/null +++ b/plugin/conn_ip_example/manifest.toml @@ -0,0 +1,15 @@ +name = "conn_ip_example" +kind = "Audit" +description = "just a test" +version = "1" +license = "" +sysVars = [ + {name="conn_ip_example_test_variable", scope="Global", value="2"}, + {name="conn_ip_example_test_variable2", scope="Session", value="2"}, +] +validate = "Validate" +onInit = "OnInit" +onShutdown = "OnShutdown" +export = [ + {extPoint="NotifyEvent", impl="NotifyEvent"} +] diff --git a/plugin/const.go b/plugin/const.go new file mode 100644 index 0000000000000..88ba5432110d6 --- /dev/null +++ b/plugin/const.go @@ -0,0 +1,70 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +// Kind presents the kind of plugin. +type Kind uint8 + +const ( + // Audit indicates it is a Audit plugin. + Audit Kind = 1 + iota + // Authentication indicate it is a Authentication plugin. + Authentication + // Schema indicate a plugin that can change TiDB schema. + Schema + // Daemon indicate a plugin that can run as daemon task. + Daemon +) + +func (k Kind) String() (str string) { + switch k { + case Audit: + str = "Audit" + case Authentication: + str = "Authentication" + case Schema: + str = "Schema" + case Daemon: + str = "Daemon" + } + return +} + +// State present the state of plugin. +type State uint8 + +const ( + // Uninitialized indicates plugin is uninitialized. + Uninitialized State = iota + // Ready indicates plugin is ready to work. + Ready + // Dying indicates plugin will be close soon. + Dying + // Disable indicate plugin is disabled. + Disable +) + +func (s State) String() (str string) { + switch s { + case Uninitialized: + str = "Uninitialized" + case Ready: + str = "Ready" + case Dying: + str = "Dying" + case Disable: + str = "Disable" + } + return +} diff --git a/plugin/errors.go b/plugin/errors.go new file mode 100644 index 0000000000000..938b5635cf4d4 --- /dev/null +++ b/plugin/errors.go @@ -0,0 +1,50 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/terror" +) + +var ( + errInvalidPluginID = createPluginError(mysql.ErrInvalidPluginID) + errInvalidPluginManifest = createPluginError(mysql.ErrInvalidPluginManifest) + errInvalidPluginName = createPluginError(mysql.ErrInvalidPluginName) + errInvalidPluginVersion = createPluginError(mysql.ErrInvalidPluginVersion) + errDuplicatePlugin = createPluginError(mysql.ErrDuplicatePlugin) + errInvalidPluginSysVarName = createPluginError(mysql.ErrInvalidPluginSysVarName) + errRequireVersionCheckFail = createPluginError(mysql.ErrRequireVersionCheckFail) + errUnsupportedReloadPlugin = createPluginError(mysql.ErrUnsupportedReloadPlugin) + errUnsupportedReloadPluginVar = createPluginError(mysql.ErrUnsupportedReloadPluginVar) +) + +func createPluginError(code terror.ErrCode) *terror.Error { + return terror.ClassPlugin.New(code, mysql.MySQLErrName[uint16(code)]) +} + +func init() { + pluginMySQLErrCodes := map[terror.ErrCode]uint16{ + mysql.ErrInvalidPluginID: mysql.ErrInvalidPluginID, + mysql.ErrInvalidPluginManifest: mysql.ErrInvalidPluginManifest, + mysql.ErrInvalidPluginName: mysql.ErrInvalidPluginName, + mysql.ErrInvalidPluginVersion: mysql.ErrInvalidPluginVersion, + mysql.ErrDuplicatePlugin: mysql.ErrDuplicatePlugin, + mysql.ErrInvalidPluginSysVarName: mysql.ErrInvalidPluginSysVarName, + mysql.ErrRequireVersionCheckFail: mysql.ErrRequireVersionCheckFail, + mysql.ErrUnsupportedReloadPlugin: mysql.ErrUnsupportedReloadPlugin, + mysql.ErrUnsupportedReloadPluginVar: mysql.ErrUnsupportedReloadPluginVar, + } + terror.ErrClassToMySQLCodes[terror.ClassPlugin] = pluginMySQLErrCodes +} diff --git a/plugin/helper.go b/plugin/helper.go new file mode 100644 index 0000000000000..1d81cd9ac952b --- /dev/null +++ b/plugin/helper.go @@ -0,0 +1,54 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "strings" + "unsafe" +) + +// DeclareAuditManifest declares manifest as AuditManifest. +func DeclareAuditManifest(m *Manifest) *AuditManifest { + return (*AuditManifest)(unsafe.Pointer(m)) +} + +// DeclareAuthenticationManifest declares manifest as AuthenticationManifest. +func DeclareAuthenticationManifest(m *Manifest) *AuthenticationManifest { + return (*AuthenticationManifest)(unsafe.Pointer(m)) +} + +// DeclareSchemaManifest declares manifest as SchemaManifest. +func DeclareSchemaManifest(m *Manifest) *SchemaManifest { + return (*SchemaManifest)(unsafe.Pointer(m)) +} + +// DeclareDaemonManifest declares manifest as DaemonManifest. +func DeclareDaemonManifest(m *Manifest) *DaemonManifest { + return (*DaemonManifest)(unsafe.Pointer(m)) +} + +// ID present plugin identity. +type ID string + +// Decode decodes a plugin id into name, version parts. +func (n ID) Decode() (name string, version string, err error) { + splits := strings.Split(string(n), "-") + if len(splits) != 2 { + err = errInvalidPluginID.GenWithStackByArgs(string(n)) + return + } + name = splits[0] + version = splits[1] + return +} diff --git a/plugin/plugin.go b/plugin/plugin.go new file mode 100644 index 0000000000000..a4a06fac15dbb --- /dev/null +++ b/plugin/plugin.go @@ -0,0 +1,373 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "context" + "path/filepath" + gplugin "plugin" + "strconv" + "strings" + "sync/atomic" + "unsafe" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/sessionctx/variable" +) + +// pluginGlobal holds all global variables for plugin. +var pluginGlobal copyOnWriteContext + +// copyOnWriteContext wraps a context follow COW idiom. +type copyOnWriteContext struct { + tiPlugins unsafe.Pointer // *plugins +} + +// plugins collects loaded plugins info. +type plugins struct { + plugins map[Kind][]Plugin + versions map[string]uint16 + dyingPlugins []Plugin +} + +// clone deep copies plugins info. +func (p *plugins) clone() *plugins { + np := &plugins{ + plugins: make(map[Kind][]Plugin, len(p.plugins)), + versions: make(map[string]uint16, len(p.versions)), + } + for key, value := range p.plugins { + np.plugins[key] = append([]Plugin(nil), value...) + } + for key, value := range p.versions { + np.versions[key] = value + } + for key, value := range p.dyingPlugins { + np.dyingPlugins[key] = value + } + return np +} + +// add adds a plugin to loaded plugin collection. +func (p plugins) add(plugin *Plugin) { + plugins, ok := p.plugins[plugin.Kind] + if !ok { + plugins = make([]Plugin, 0) + } + plugins = append(plugins, *plugin) + p.plugins[plugin.Kind] = plugins + p.versions[plugin.Name] = plugin.Version +} + +// plugins got plugin in COW context. +func (p copyOnWriteContext) plugins() *plugins { + return (*plugins)(atomic.LoadPointer(&p.tiPlugins)) +} + +// Config presents the init configuration for plugin framework. +type Config struct { + Plugins []string + PluginDir string + GlobalSysVar *map[string]*variable.SysVar + PluginVarNames *[]string + SkipWhenFail bool + EnvVersion map[string]uint16 +} + +// Plugin presents a TiDB plugin. +type Plugin struct { + *Manifest + library *gplugin.Plugin + State State + Path string +} + +type validateMode int + +const ( + initMode validateMode = iota + reloadMode +) + +func (p *Plugin) validate(ctx context.Context, tiPlugins *plugins, mode validateMode) error { + if mode == reloadMode { + var oldPlugin *Plugin + for i, item := range tiPlugins.plugins[p.Kind] { + if item.Name == p.Name { + oldPlugin = &tiPlugins.plugins[p.Kind][i] + break + } + } + if oldPlugin == nil { + return errUnsupportedReloadPlugin.GenWithStackByArgs(p.Name) + } + if len(p.SysVars) != len(oldPlugin.SysVars) { + return errUnsupportedReloadPluginVar.GenWithStackByArgs("") + } + for varName, varVal := range p.SysVars { + if oldPlugin.SysVars[varName] == nil || *oldPlugin.SysVars[varName] != *varVal { + return errUnsupportedReloadPluginVar.GenWithStackByArgs(varVal) + } + } + } + if p.RequireVersion != nil { + for component, reqVer := range p.RequireVersion { + if ver, ok := tiPlugins.versions[component]; !ok || ver < reqVer { + return errRequireVersionCheckFail.GenWithStackByArgs(p.Name, component, reqVer, ver) + } + } + } + if p.SysVars != nil { + for varName := range p.SysVars { + if !strings.HasPrefix(varName, p.Name) { + return errInvalidPluginSysVarName.GenWithStackByArgs(p.Name, varName, p.Name) + } + } + } + if p.Manifest.Validate != nil { + if err := p.Manifest.Validate(ctx, p.Manifest); err != nil { + return err + } + } + return nil +} + +// Init initializes the plugin and load plugin by config param. +// This method isn't thread-safe and must be called before any other plugin operation. +func Init(ctx context.Context, cfg Config) (err error) { + tiPlugins := &plugins{ + plugins: make(map[Kind][]Plugin), + versions: make(map[string]uint16), + dyingPlugins: make([]Plugin, 0), + } + + // Setup component version info for plugin running env. + for component, version := range cfg.EnvVersion { + tiPlugins.versions[component] = version + } + + // Load plugin dl & manifest. + for _, pluginID := range cfg.Plugins { + var pName string + pName, _, err = ID(pluginID).Decode() + if err != nil { + err = errors.Trace(err) + return + } + // Check duplicate. + _, dup := tiPlugins.versions[pName] + if dup { + if cfg.SkipWhenFail { + continue + } + err = errDuplicatePlugin.GenWithStackByArgs(pluginID) + return + } + // Load dl. + var plugin Plugin + plugin, err = loadOne(cfg.PluginDir, ID(pluginID)) + if err != nil { + if cfg.SkipWhenFail { + continue + } + return + } + tiPlugins.add(&plugin) + } + + // Cross validate & Load plugins. + for kind := range tiPlugins.plugins { + for i := range tiPlugins.plugins[kind] { + if err = tiPlugins.plugins[kind][i].validate(ctx, tiPlugins, initMode); err != nil { + if cfg.SkipWhenFail { + tiPlugins.plugins[kind][i].State = Disable + err = nil + continue + } + return + } + p := tiPlugins.plugins[kind][i] + if err = p.OnInit(ctx, p.Manifest); err != nil { + if cfg.SkipWhenFail { + tiPlugins.plugins[kind][i].State = Disable + err = nil + continue + } + return + } + if cfg.GlobalSysVar != nil { + for key, value := range tiPlugins.plugins[kind][i].SysVars { + (*cfg.GlobalSysVar)[key] = value + if value.Scope != variable.ScopeSession && cfg.PluginVarNames != nil { + *cfg.PluginVarNames = append(*cfg.PluginVarNames, key) + } + } + } + tiPlugins.plugins[kind][i].State = Ready + } + } + pluginGlobal = copyOnWriteContext{tiPlugins: unsafe.Pointer(tiPlugins)} + err = nil + return +} + +func loadOne(dir string, pluginID ID) (plugin Plugin, err error) { + plugin.Path = filepath.Join(dir, string(pluginID)+LibrarySuffix) + plugin.library, err = gplugin.Open(plugin.Path) + if err != nil { + err = errors.Trace(err) + return + } + manifestSym, err := plugin.library.Lookup(ManifestSymbol) + if err != nil { + err = errors.Trace(err) + return + } + manifest, ok := manifestSym.(func() *Manifest) + if !ok { + err = errInvalidPluginManifest.GenWithStackByArgs(string(pluginID)) + return + } + pName, pVersion, err := pluginID.Decode() + if err != nil { + err = errors.Trace(err) + return + } + plugin.Manifest = manifest() + if plugin.Name != pName { + err = errInvalidPluginName.GenWithStackByArgs(string(pluginID), plugin.Name) + return + } + if strconv.Itoa(int(plugin.Version)) != pVersion { + err = errInvalidPluginVersion.GenWithStackByArgs(string(pluginID)) + return + } + return +} + +// Reload hot swap a old plugin with new version. +// Limit: loaded plugins shouldn't be unload and only be mark dying. +func Reload(ctx context.Context, cfg Config, pluginID ID) (err error) { + newPlugin, err := loadOne(cfg.PluginDir, pluginID) + if err != nil { + return + } + _, err = replace(ctx, cfg, newPlugin.Name, newPlugin) + return +} + +func replace(ctx context.Context, cfg Config, name string, newPlugin Plugin) (replaced bool, err error) { + + oldPlugins := pluginGlobal.plugins() + if oldPlugins.versions[name] == newPlugin.Version { + replaced = false + return + } + err = newPlugin.validate(ctx, oldPlugins, reloadMode) + if err != nil { + return + } + err = newPlugin.OnInit(ctx, newPlugin.Manifest) + if err != nil { + return + } + if cfg.GlobalSysVar != nil { + for key, value := range newPlugin.SysVars { + (*cfg.GlobalSysVar)[key] = value + } + } + + for { + oldPlugins = pluginGlobal.plugins() + newPlugins := oldPlugins.clone() + replaced = true + tiPluginKind := newPlugins.plugins[newPlugin.Kind] + var oldPlugin *Plugin + for i, p := range tiPluginKind { + if p.Name == name { + oldPlugin = &tiPluginKind[i] + tiPluginKind = append(tiPluginKind[:i], tiPluginKind[i+1:]...) + } + } + + if oldPlugin != nil { + oldPlugin.State = Dying + newPlugins.dyingPlugins = append(newPlugins.dyingPlugins, *oldPlugin) + err = oldPlugin.OnShutdown(ctx, oldPlugin.Manifest) + if err != nil { + // When shutdown failure, the plugin is in stranger state, so make it as Dying. + return + } + } + + newPlugin.State = Ready + tiPluginKind = append(tiPluginKind, newPlugin) + newPlugins.plugins[newPlugin.Kind] = tiPluginKind + newPlugins.versions[newPlugin.Name] = newPlugin.Version + + if atomic.CompareAndSwapPointer(&pluginGlobal.tiPlugins, unsafe.Pointer(oldPlugins), unsafe.Pointer(newPlugins)) { + return + } + } +} + +// Shutdown cleanups all plugin resources. +// Notice: it just cleanups the resource of plugin, but cannot unload plugins(limited by go plugin). +func Shutdown(ctx context.Context) { + for { + tiPlugins := pluginGlobal.plugins() + for _, plugins := range tiPlugins.plugins { + for _, p := range plugins { + p.State = Dying + if err := p.OnShutdown(ctx, p.Manifest); err != nil { + } + } + } + if atomic.CompareAndSwapPointer(&pluginGlobal.tiPlugins, unsafe.Pointer(tiPlugins), nil) { + return + } + } +} + +// Get finds and returns plugin by kind and name parameters. +func Get(kind Kind, name string) *Plugin { + plugins := pluginGlobal.plugins() + if plugins == nil { + return nil + } + for _, p := range plugins.plugins[kind] { + if p.Name == name { + return &p + } + } + return nil +} + +// GetByKind finds and returns plugin by kind parameters. +func GetByKind(kind Kind) []Plugin { + plugins := pluginGlobal.plugins() + if plugins == nil { + return nil + } + return plugins.plugins[kind] +} + +// GetAll finds and returns all plugins. +func GetAll() map[Kind][]Plugin { + plugins := pluginGlobal.plugins() + if plugins == nil { + return nil + } + return plugins.plugins +} diff --git a/plugin/spi.go b/plugin/spi.go new file mode 100644 index 0000000000000..15684f7b0e154 --- /dev/null +++ b/plugin/spi.go @@ -0,0 +1,77 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "context" + "reflect" + "unsafe" + + "github.com/pingcap/tidb/sessionctx/variable" +) + +const ( + // LibrarySuffix defines TiDB plugin's file suffix. + LibrarySuffix = ".so" + // ManifestSymbol defines TiDB plugin's entrance symbol. + // Plugin take manifest info from this symbol. + ManifestSymbol = "PluginManifest" +) + +// Manifest describes plugin info and how it can do by plugin itself. +type Manifest struct { + Kind Kind + Name string + Description string + Version uint16 + RequireVersion map[string]uint16 + License string + BuildTime string + SysVars map[string]*variable.SysVar + Validate func(ctx context.Context, manifest *Manifest) error + OnInit func(ctx context.Context, manifest *Manifest) error + OnShutdown func(ctx context.Context, manifest *Manifest) error +} + +// ExportManifest exports a manifest to TiDB as a known format. +// it just casts sub-manifest to manifest. +func ExportManifest(m interface{}) *Manifest { + v := reflect.ValueOf(m) + return (*Manifest)(unsafe.Pointer(v.Pointer())) +} + +// AuditManifest presents a sub-manifest that every audit plugin must provide. +type AuditManifest struct { + Manifest + NotifyEvent func(ctx context.Context, sctx *variable.SessionVars) error +} + +// AuthenticationManifest presents a sub-manifest that every audit plugin must provide. +type AuthenticationManifest struct { + Manifest + AuthenticateUser func() + GenerateAuthenticationString func() + ValidateAuthenticationString func() + SetSalt func() +} + +// SchemaManifest presents a sub-manifest that every schema plugins must provide. +type SchemaManifest struct { + Manifest +} + +// DaemonManifest presents a sub-manifest that every DaemonManifest plugins must provide. +type DaemonManifest struct { + Manifest +} diff --git a/plugin/spi_test.go b/plugin/spi_test.go new file mode 100644 index 0000000000000..98e676acfcc3a --- /dev/null +++ b/plugin/spi_test.go @@ -0,0 +1,51 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin_test + +import ( + "context" + "testing" + + "github.com/pingcap/tidb/plugin" + "github.com/pingcap/tidb/sessionctx/variable" +) + +func TestExportManifest(t *testing.T) { + callRecorder := struct { + OnInitCalled bool + NotifyEventCalled bool + }{} + manifest := &plugin.AuditManifest{ + Manifest: plugin.Manifest{ + Kind: plugin.Authentication, + Name: "test audit", + Version: 1, + OnInit: func(ctx context.Context, manifest *plugin.Manifest) error { + callRecorder.OnInitCalled = true + return nil + }, + }, + NotifyEvent: func(ctx context.Context, sctx *variable.SessionVars) error { + callRecorder.NotifyEventCalled = true + return nil + }, + } + exported := plugin.ExportManifest(manifest) + exported.OnInit(context.Background(), exported) + audit := plugin.DeclareAuditManifest(exported) + audit.NotifyEvent(context.Background(), nil) + if !callRecorder.NotifyEventCalled || !callRecorder.OnInitCalled { + t.Fatalf("export test failure") + } +} diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index bba0e41fcd40b..fd80102ba61a4 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -277,16 +277,16 @@ func (p *MySQLPrivilege) loadTable(sctx sessionctx.Context, sql string, defer terror.Call(rs.Close) fs := rs.Fields() - chk := rs.NewChunk() + req := rs.NewRecordBatch() for { - err = rs.Next(context.TODO(), chk) + err = rs.Next(context.TODO(), req) if err != nil { return errors.Trace(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { return nil } - it := chunk.NewIterator4Chunk(chk) + it := chunk.NewIterator4Chunk(req.Chunk) for row := it.Begin(); row != it.End(); row = it.Next() { err = decodeTableRow(row, fs) if err != nil { @@ -296,7 +296,7 @@ func (p *MySQLPrivilege) loadTable(sctx sessionctx.Context, sql string, // NOTE: decodeTableRow decodes data from a chunk Row, that is a shallow copy. // The result will reference memory in the chunk, so the chunk must not be reused // here, otherwise some werid bug will happen! - chk = chunk.Renew(chk, sctx.GetSessionVars().MaxChunkSize) + req.Chunk = chunk.Renew(req.Chunk, sctx.GetSessionVars().MaxChunkSize) } } diff --git a/server/conn.go b/server/conn.go index 9068b09f886fc..7eebdd7a8bdbe 100644 --- a/server/conn.go +++ b/server/conn.go @@ -525,7 +525,6 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { // This function returns and the connection is closed if there is an IO error or there is a panic. func (cc *clientConn) Run() { const size = 4096 - closedOutside := false defer func() { r := recover() if r != nil { @@ -535,7 +534,7 @@ func (cc *clientConn) Run() { log.Errorf("lastCmd %s, %v, %s", cc.lastCmd, r, buf) metrics.PanicCounter.WithLabelValues(metrics.LabelSession).Inc() } - if !closedOutside { + if atomic.LoadInt32(&cc.status) != connStatusShutdown { err := cc.Close() terror.Log(errors.Trace(err)) } @@ -548,9 +547,6 @@ func (cc *clientConn) Run() { // by CAS operation, it would then take some actions accordingly. for { if atomic.CompareAndSwapInt32(&cc.status, connStatusDispatching, connStatusReading) == false { - if atomic.LoadInt32(&cc.status) == connStatusShutdown { - closedOutside = true - } return } @@ -577,9 +573,6 @@ func (cc *clientConn) Run() { } if atomic.CompareAndSwapInt32(&cc.status, connStatusReading, connStatusDispatching) == false { - if atomic.LoadInt32(&cc.status) == connStatusShutdown { - closedOutside = true - } return } @@ -1131,11 +1124,11 @@ func (cc *clientConn) writeColumnInfo(columns []*ColumnInfo, serverStatus uint16 // serverStatus, a flag bit represents server information func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16) error { data := make([]byte, 4, 1024) - chk := rs.NewChunk() + req := rs.NewRecordBatch() gotColumnInfo := false for { // Here server.tidbResultSet implements Next method. - err := rs.Next(ctx, chk) + err := rs.Next(ctx, req) if err != nil { return errors.Trace(err) } @@ -1149,16 +1142,16 @@ func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool } gotColumnInfo = true } - rowCount := chk.NumRows() + rowCount := req.NumRows() if rowCount == 0 { break } for i := 0; i < rowCount; i++ { data = data[0:4] if binary { - data, err = dumpBinaryRow(data, rs.Columns(), chk.GetRow(i)) + data, err = dumpBinaryRow(data, rs.Columns(), req.GetRow(i)) } else { - data, err = dumpTextRow(data, rs.Columns(), chk.GetRow(i)) + data, err = dumpTextRow(data, rs.Columns(), req.GetRow(i)) } if err != nil { return errors.Trace(err) @@ -1179,22 +1172,22 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet fetchedRows := rs.GetFetchedRows() // if fetchedRows is not enough, getting data from recordSet. - chk := rs.NewChunk() + req := rs.NewRecordBatch() for len(fetchedRows) < fetchSize { // Here server.tidbResultSet implements Next method. - err := rs.Next(ctx, chk) + err := rs.Next(ctx, req) if err != nil { return errors.Trace(err) } - rowCount := chk.NumRows() + rowCount := req.NumRows() if rowCount == 0 { break } // filling fetchedRows with chunk for i := 0; i < rowCount; i++ { - fetchedRows = append(fetchedRows, chk.GetRow(i)) + fetchedRows = append(fetchedRows, req.GetRow(i)) } - chk = chunk.Renew(chk, cc.ctx.GetSessionVars().MaxChunkSize) + req.Chunk = chunk.Renew(req.Chunk, cc.ctx.GetSessionVars().MaxChunkSize) } // tell the client COM_STMT_FETCH has finished by setting proper serverStatus, diff --git a/server/driver.go b/server/driver.go index 4e577ffbd7be8..8ac62457c16cb 100644 --- a/server/driver.go +++ b/server/driver.go @@ -136,8 +136,8 @@ type PreparedStatement interface { // ResultSet is the result set of an query. type ResultSet interface { Columns() []*ColumnInfo - NewChunk() *chunk.Chunk - Next(context.Context, *chunk.Chunk) error + NewRecordBatch() *chunk.RecordBatch + Next(context.Context, *chunk.RecordBatch) error StoreFetchedRows(rows []chunk.Row) GetFetchedRows() []chunk.Row Close() error diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 7d57e966c3bad..67dcd546dc54b 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -356,12 +356,12 @@ type tidbResultSet struct { closed bool } -func (trs *tidbResultSet) NewChunk() *chunk.Chunk { - return trs.recordSet.NewChunk() +func (trs *tidbResultSet) NewRecordBatch() *chunk.RecordBatch { + return trs.recordSet.NewRecordBatch() } -func (trs *tidbResultSet) Next(ctx context.Context, chk *chunk.Chunk) error { - return trs.recordSet.Next(ctx, chk) +func (trs *tidbResultSet) Next(ctx context.Context, req *chunk.RecordBatch) error { + return trs.recordSet.Next(ctx, req) } func (trs *tidbResultSet) StoreFetchedRows(rows []chunk.Row) { diff --git a/server/server.go b/server/server.go index 6f52d0ddb2d19..fe04fcb6d395e 100644 --- a/server/server.go +++ b/server/server.go @@ -29,9 +29,11 @@ package server import ( + "context" "crypto/tls" "crypto/x509" "fmt" + "io" "io/ioutil" "math/rand" "net" @@ -80,6 +82,7 @@ type Server struct { tlsConfig *tls.Config driver IDriver listener net.Listener + socket net.Listener rwlock *sync.RWMutex concurrentLimiter *TokenLimiter clients map[uint32]*clientConn @@ -133,6 +136,39 @@ func (s *Server) isUnixSocket() bool { return s.cfg.Socket != "" } +func (s *Server) forwardUnixSocketToTCP() { + addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) + for { + if s.listener == nil { + return // server shutdown has started + } + if uconn, err := s.socket.Accept(); err == nil { + log.Infof("server socket forwarding from [%s] to [%s]", s.cfg.Socket, addr) + go s.handleForwardedConnection(uconn, addr) + } else { + if s.listener != nil { + log.Errorf("server failed to forward from [%s] to [%s], err: %s", s.cfg.Socket, addr, err) + } + } + } +} + +func (s *Server) handleForwardedConnection(uconn net.Conn, addr string) { + defer terror.Call(uconn.Close) + if tconn, err := net.Dial("tcp", addr); err == nil { + go func() { + if _, err := io.Copy(uconn, tconn); err != nil { + log.Warningf("copy server to socket failed: %s", err) + } + }() + if _, err := io.Copy(tconn, uconn); err != nil { + log.Warningf("socket forward copy failed: %s", err) + } + } else { + log.Warningf("socket forward failed: could not connect to [%s], err: %s", addr, err) + } +} + // NewServer creates a new Server. func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { s := &Server{ @@ -151,15 +187,24 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { } var err error - if cfg.Socket != "" { - if s.listener, err = net.Listen("unix", cfg.Socket); err == nil { - log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket) - } - } else { + + if s.cfg.Host != "" && s.cfg.Port != 0 { addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) if s.listener, err = net.Listen("tcp", addr); err == nil { log.Infof("Server is running MySQL Protocol at [%s]", addr) + if cfg.Socket != "" { + if s.socket, err = net.Listen("unix", s.cfg.Socket); err == nil { + log.Infof("Server redirecting [%s] to [%s]", s.cfg.Socket, addr) + go s.forwardUnixSocketToTCP() + } + } + } + } else if cfg.Socket != "" { + if s.listener, err = net.Listen("unix", cfg.Socket); err == nil { + log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket) } + } else { + err = errors.New("Server not configured to listen on either -socket or -host and -port") } if cfg.ProxyProtocol.Networks != "" { @@ -292,6 +337,11 @@ func (s *Server) Close() { terror.Log(errors.Trace(err)) s.listener = nil } + if s.socket != nil { + err := s.socket.Close() + terror.Log(errors.Trace(err)) + s.socket = nil + } if s.statusServer != nil { err := s.statusServer.Close() terror.Log(errors.Trace(err)) @@ -387,22 +437,49 @@ func (s *Server) KillAllConnections() { } } +var gracefulCloseConnectionsTimeout = 15 * time.Second + +// TryGracefulDown will try to gracefully close all connection first with timeout. if timeout, will close all connection directly. +func (s *Server) TryGracefulDown() { + ctx, cancel := context.WithTimeout(context.Background(), gracefulCloseConnectionsTimeout) + defer cancel() + done := make(chan struct{}) + go func() { + s.GracefulDown(ctx, done) + }() + select { + case <-ctx.Done(): + s.KillAllConnections() + case <-done: + return + } +} + // GracefulDown waits all clients to close. -func (s *Server) GracefulDown() { +func (s *Server) GracefulDown(ctx context.Context, done chan struct{}) { log.Info("[server] graceful shutdown.") metrics.ServerEventCounter.WithLabelValues(metrics.EventGracefulDown).Inc() count := s.ConnectionCount() for i := 0; count > 0; i++ { - time.Sleep(time.Second) s.kickIdleConnection() count = s.ConnectionCount() + if count == 0 { + break + } // Print information for every 30s. if i%30 == 0 { log.Infof("graceful shutdown...connection count %d\n", count) } + ticker := time.After(time.Second) + select { + case <-ctx.Done(): + return + case <-ticker: + } } + close(done) } func (s *Server) kickIdleConnection() { @@ -419,7 +496,7 @@ func (s *Server) kickIdleConnection() { for _, cc := range conns { err := cc.Close() if err != nil { - log.Error("close connection error:", err) + log.Errorf("close connection error: %s", err) } } } diff --git a/server/tidb_test.go b/server/tidb_test.go index 2a7a18264f15e..65a96c19072fe 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -160,9 +160,34 @@ func (ts *TidbTestSuite) TestMultiStatements(c *C) { runTestMultiStatements(c) } +func (ts *TidbTestSuite) TestSocketForwarding(c *C) { + cfg := config.NewConfig() + cfg.Socket = "/tmp/tidbtest.sock" + cfg.Port = 3999 + os.Remove(cfg.Socket) + cfg.Status.ReportStatus = false + + server, err := NewServer(cfg, ts.tidbdrv) + c.Assert(err, IsNil) + go server.Run() + time.Sleep(time.Millisecond * 100) + defer server.Close() + + runTestRegression(c, func(config *mysql.Config) { + config.User = "root" + config.Net = "unix" + config.Addr = "/tmp/tidbtest.sock" + config.DBName = "test" + config.Strict = true + }, "SocketRegression") +} + func (ts *TidbTestSuite) TestSocket(c *C) { cfg := config.NewConfig() cfg.Socket = "/tmp/tidbtest.sock" + cfg.Port = 0 + os.Remove(cfg.Socket) + cfg.Host = "" cfg.Status.ReportStatus = false server, err := NewServer(cfg, ts.tidbdrv) @@ -178,6 +203,7 @@ func (ts *TidbTestSuite) TestSocket(c *C) { config.DBName = "test" config.Strict = true }, "SocketRegression") + } // generateCert generates a private key and a certificate in PEM format based on parameters. @@ -415,14 +441,14 @@ func (ts *TidbTestSuite) TestCreateTableFlen(c *C) { c.Assert(err, IsNil) rs, err := qctx.Execute(ctx, "show create table t1") c.Assert(err, IsNil) - chk := rs[0].NewChunk() - err = rs[0].Next(ctx, chk) + req := rs[0].NewRecordBatch() + err = rs[0].Next(ctx, req) c.Assert(err, IsNil) cols := rs[0].Columns() c.Assert(err, IsNil) c.Assert(len(cols), Equals, 2) c.Assert(int(cols[0].ColumnLength), Equals, 5*tmysql.MaxBytesOfCharacter) - c.Assert(int(cols[1].ColumnLength), Equals, len(chk.GetRow(0).GetString(1))*tmysql.MaxBytesOfCharacter) + c.Assert(int(cols[1].ColumnLength), Equals, len(req.GetRow(0).GetString(1))*tmysql.MaxBytesOfCharacter) // for issue#5246 rs, err = qctx.Execute(ctx, "select y, z from t1") @@ -445,8 +471,8 @@ func (ts *TidbTestSuite) TestShowTablesFlen(c *C) { c.Assert(err, IsNil) rs, err := qctx.Execute(ctx, "show tables") c.Assert(err, IsNil) - chk := rs[0].NewChunk() - err = rs[0].Next(ctx, chk) + req := rs[0].NewRecordBatch() + err = rs[0].Next(ctx, req) c.Assert(err, IsNil) cols := rs[0].Columns() c.Assert(err, IsNil) diff --git a/session/bench_test.go b/session/bench_test.go index 395180962fcc2..df8c8d89ce64a 100644 --- a/session/bench_test.go +++ b/session/bench_test.go @@ -84,16 +84,16 @@ func prepareJoinBenchData(se Session, colType string, valueFormat string, valueC } func readResult(ctx context.Context, rs sqlexec.RecordSet, count int) { - chk := rs.NewChunk() + req := rs.NewRecordBatch() for count > 0 { - err := rs.Next(ctx, chk) + err := rs.Next(ctx, req) if err != nil { log.Fatal(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { log.Fatal(count) } - count -= chk.NumRows() + count -= req.NumRows() } rs.Close() } diff --git a/session/bootstrap.go b/session/bootstrap.go index bf73987d9cbcd..4ef7e2ee2a37a 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -262,6 +262,7 @@ const ( version22 = 22 version23 = 23 version24 = 24 + version25 = 25 ) func checkBootstrapped(s Session) (bool, error) { @@ -303,12 +304,12 @@ func getTiDBVar(s Session, name string) (sVal string, isNull bool, e error) { } r := rs[0] defer terror.Call(r.Close) - chk := r.NewChunk() - err = r.Next(ctx, chk) - if err != nil || chk.NumRows() == 0 { + req := r.NewRecordBatch() + err = r.Next(ctx, req) + if err != nil || req.NumRows() == 0 { return "", true, errors.Trace(err) } - row := chk.GetRow(0) + row := req.GetRow(0) if row.IsNull(0) { return "", true, nil } @@ -416,6 +417,10 @@ func upgrade(s Session) { upgradeToVer24(s) } + if ver < version25 { + upgradeToVer25(s) + } + updateBootstrapVer(s) _, err = s.Execute(context.Background(), "COMMIT") @@ -536,10 +541,10 @@ func upgradeToVer12(s Session) { r := rs[0] sqls := make([]string, 0, 1) defer terror.Call(r.Close) - chk := r.NewChunk() - it := chunk.NewIterator4Chunk(chk) - err = r.Next(ctx, chk) - for err == nil && chk.NumRows() != 0 { + req := r.NewRecordBatch() + it := chunk.NewIterator4Chunk(req.Chunk) + err = r.Next(ctx, req) + for err == nil && req.NumRows() != 0 { for row := it.Begin(); row != it.End(); row = it.Next() { user := row.GetString(0) host := row.GetString(1) @@ -550,7 +555,7 @@ func upgradeToVer12(s Session) { updateSQL := fmt.Sprintf(`UPDATE HIGH_PRIORITY mysql.user set password = "%s" where user="%s" and host="%s"`, newPass, user, host) sqls = append(sqls, updateSQL) } - err = r.Next(ctx, chk) + err = r.Next(ctx, req) } terror.MustNil(err) @@ -670,6 +675,13 @@ func upgradeToVer24(s Session) { writeSystemTZ(s) } +// upgradeToVer25 updates tidb_max_chunk_size to new low bound value 32 if previous value is small than 32. +func upgradeToVer25(s Session) { + sql := fmt.Sprintf("UPDATE HIGH_PRIORITY %[1]s.%[2]s SET VARIABLE_VALUE = '%[4]d' WHERE VARIABLE_NAME = '%[3]s' AND VARIABLE_VALUE < %[4]d", + mysql.SystemDB, mysql.GlobalVariablesTable, variable.TiDBMaxChunkSize, variable.DefInitChunkSize) + mustExecute(s, sql) +} + // updateBootstrapVer updates bootstrap version variable in mysql.TiDB table. func updateBootstrapVer(s Session) { // Update bootstrap version. diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index 63c6c71668c33..2c8cd6ecff5a4 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -51,11 +51,11 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { r := mustExecSQL(c, se, `select * from user;`) c.Assert(r, NotNil) ctx := context.Background() - chk := r.NewChunk() - err := r.Next(ctx, chk) + req := r.NewRecordBatch() + err := r.Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) - datums := statistics.RowToDatums(chk.GetRow(0), r.Fields()) + c.Assert(req.NumRows() == 0, IsFalse) + datums := statistics.RowToDatums(req.GetRow(0), r.Fields()) match(c, datums, []byte(`%`), []byte("root"), []byte(""), "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "anyhost"}, []byte(""), []byte("")), IsTrue) @@ -67,10 +67,10 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { // Check privilege tables. r = mustExecSQL(c, se, "SELECT COUNT(*) from mysql.global_variables;") c.Assert(r, NotNil) - chk = r.NewChunk() - err = r.Next(ctx, chk) + req = r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.GetRow(0).GetInt64(0), Equals, globalVarsCount()) + c.Assert(req.GetRow(0).GetInt64(0), Equals, globalVarsCount()) // Check a storage operations are default autocommit after the second start. mustExecSQL(c, se, "USE test;") @@ -88,10 +88,10 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { r = mustExecSQL(c, se, "select * from t") c.Assert(r, NotNil) - chk = r.NewChunk() - err = r.Next(ctx, chk) + req = r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - datums = statistics.RowToDatums(chk.GetRow(0), r.Fields()) + datums = statistics.RowToDatums(req.GetRow(0), r.Fields()) match(c, datums, 3) mustExecSQL(c, se, "drop table if exists t") se.Close() @@ -154,11 +154,11 @@ func (s *testBootstrapSuite) TestBootstrapWithError(c *C) { se := newSession(c, store, s.dbNameBootstrap) mustExecSQL(c, se, "USE mysql;") r := mustExecSQL(c, se, `select * from user;`) - chk := r.NewChunk() - err = r.Next(ctx, chk) + req := r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) - row := chk.GetRow(0) + c.Assert(req.NumRows() == 0, IsFalse) + row := req.GetRow(0) datums := statistics.RowToDatums(row, r.Fields()) match(c, datums, []byte(`%`), []byte("root"), []byte(""), "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") c.Assert(r.Close(), IsNil) @@ -170,19 +170,19 @@ func (s *testBootstrapSuite) TestBootstrapWithError(c *C) { mustExecSQL(c, se, "SELECT * from mysql.columns_priv;") // Check global variables. r = mustExecSQL(c, se, "SELECT COUNT(*) from mysql.global_variables;") - chk = r.NewChunk() - err = r.Next(ctx, chk) + req = r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - v := chk.GetRow(0) + v := req.GetRow(0) c.Assert(v.GetInt64(0), Equals, globalVarsCount()) c.Assert(r.Close(), IsNil) r = mustExecSQL(c, se, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="bootstrapped";`) - chk = r.NewChunk() - err = r.Next(ctx, chk) + req = r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) - row = chk.GetRow(0) + c.Assert(req.NumRows() == 0, IsFalse) + row = req.GetRow(0) c.Assert(row.Len(), Equals, 1) c.Assert(row.GetBytes(0), BytesEquals, []byte("True")) c.Assert(r.Close(), IsNil) @@ -199,11 +199,11 @@ func (s *testBootstrapSuite) TestUpgrade(c *C) { // bootstrap with currentBootstrapVersion r := mustExecSQL(c, se, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version";`) - chk := r.NewChunk() - err := r.Next(ctx, chk) - row := chk.GetRow(0) + req := r.NewRecordBatch() + err := r.Next(ctx, req) + row := req.GetRow(0) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) + c.Assert(req.NumRows() == 0, IsFalse) c.Assert(row.Len(), Equals, 1) c.Assert(row.GetBytes(0), BytesEquals, []byte(fmt.Sprintf("%d", currentBootstrapVersion))) c.Assert(r.Close(), IsNil) @@ -229,10 +229,10 @@ func (s *testBootstrapSuite) TestUpgrade(c *C) { delete(storeBootstrapped, store.UUID()) // Make sure the version is downgraded. r = mustExecSQL(c, se1, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version";`) - chk = r.NewChunk() - err = r.Next(ctx, chk) + req = r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsTrue) + c.Assert(req.NumRows() == 0, IsTrue) c.Assert(r.Close(), IsNil) ver, err = getBootstrapVersion(se1) @@ -245,11 +245,11 @@ func (s *testBootstrapSuite) TestUpgrade(c *C) { defer dom1.Close() se2 := newSession(c, store, s.dbName) r = mustExecSQL(c, se2, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version";`) - chk = r.NewChunk() - err = r.Next(ctx, chk) + req = r.NewRecordBatch() + err = r.Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) - row = chk.GetRow(0) + c.Assert(req.NumRows() == 0, IsFalse) + row = req.GetRow(0) c.Assert(row.Len(), Equals, 1) c.Assert(row.GetBytes(0), BytesEquals, []byte(fmt.Sprintf("%d", currentBootstrapVersion))) c.Assert(r.Close(), IsNil) diff --git a/session/session.go b/session/session.go index 776dcdb7e5985..419777ae197da 100644 --- a/session/session.go +++ b/session/session.go @@ -39,6 +39,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/kv" @@ -46,6 +47,7 @@ import ( "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/owner" plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/plugin" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/sessionctx" @@ -514,23 +516,42 @@ func (s *session) isRetryableError(err error) bool { return kv.IsRetryableError(err) || domain.ErrInfoSchemaChanged.Equal(err) } -func (s *session) retry(ctx context.Context, maxCnt uint) error { - connID := s.sessionVars.ConnectionID - if s.sessionVars.TxnCtx.ForUpdate { - return errForUpdateCantRetry.GenWithStackByArgs(connID) +func (s *session) checkTxnAborted(stmt sqlexec.Statement) error { + if s.txn.doNotCommit == nil { + return nil } - s.sessionVars.RetryInfo.Retrying = true + // If the transaction is aborted, the following statements do not need to execute, except `commit` and `rollback`, + // because they are used to finish the aborted transaction. + if _, ok := stmt.(*executor.ExecStmt).StmtNode.(*ast.CommitStmt); ok { + return nil + } + if _, ok := stmt.(*executor.ExecStmt).StmtNode.(*ast.RollbackStmt); ok { + return nil + } + return errors.New("current transaction is aborted, commands ignored until end of transaction block") +} + +func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { var retryCnt uint defer func() { s.sessionVars.RetryInfo.Retrying = false - s.txn.changeToInvalid() // retryCnt only increments on retryable error, so +1 here. metrics.SessionRetry.Observe(float64(retryCnt + 1)) s.sessionVars.SetStatusFlag(mysql.ServerStatusInTrans, false) + if err != nil { + s.rollbackOnError(ctx) + } + s.txn.changeToInvalid() }() + connID := s.sessionVars.ConnectionID + s.sessionVars.RetryInfo.Retrying = true + if s.sessionVars.TxnCtx.ForUpdate { + err = errForUpdateCantRetry.GenWithStackByArgs(connID) + return err + } + nh := GetHistory(s) - var err error var schemaVersion int64 sessVars := s.GetSessionVars() orgStartTS := sessVars.TxnCtx.StartTS @@ -739,17 +760,17 @@ func createSessionWithDomainFunc(store kv.Storage) func(*domain.Domain) (pools.R func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet) ([]chunk.Row, error) { var rows []chunk.Row - chk := rs.NewChunk() + req := rs.NewRecordBatch() for { - err := rs.Next(ctx, chk) - if err != nil || chk.NumRows() == 0 { + err := rs.Next(ctx, req) + if err != nil || req.NumRows() == 0 { return rows, errors.Trace(err) } - iter := chunk.NewIterator4Chunk(chk) + iter := chunk.NewIterator4Chunk(req.Chunk) for r := iter.Begin(); r != iter.End(); r = iter.Next() { rows = append(rows, r) } - chk = chunk.Renew(chk, se.sessionVars.MaxChunkSize) + req.Chunk = chunk.Renew(req.Chunk, se.sessionVars.MaxChunkSize) } } @@ -825,11 +846,10 @@ func (s *session) SetGlobalSysVar(name, value string) error { if err != nil { return errors.Trace(err) } - + name = strings.ToLower(name) sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`, - mysql.SystemDB, mysql.GlobalVariablesTable, strings.ToLower(name), sVal) + mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal) _, _, err = s.ExecRestrictedSQL(s, sql) - return errors.Trace(err) } @@ -1249,15 +1269,30 @@ func loadSystemTZ(se *session) (string, error) { log.Error(errors.ErrorStack(err)) } }() - chk := rss[0].NewChunk() - if err := rss[0].Next(context.Background(), chk); err != nil { + req := rss[0].NewRecordBatch() + if err := rss[0].Next(context.Background(), req); err != nil { return "", errors.Trace(err) } - return chk.GetRow(0).GetString(0), nil + return req.GetRow(0).GetString(0), nil } // BootstrapSession runs the first time when the TiDB server start. func BootstrapSession(store kv.Storage) (*domain.Domain, error) { + cfg := config.GetGlobalConfig() + if len(cfg.Plugin.Load) > 0 { + err := plugin.Init(context.Background(), plugin.Config{ + Plugins: strings.Split(cfg.Plugin.Load, ","), + PluginDir: cfg.Plugin.Dir, + GlobalSysVar: &variable.SysVars, + PluginVarNames: &variable.PluginVarNames, + }) + if err != nil { + return nil, err + } + } + + initLoadCommonGlobalVarsSQL() + ver := getStoreBootstrapVersion(store) if ver == notBootstrapped { runInBootstrapSession(store, bootstrap) @@ -1379,7 +1414,7 @@ func createSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, er const ( notBootstrapped = 0 - currentBootstrapVersion = 24 + currentBootstrapVersion = 25 ) func getStoreBootstrapVersion(store kv.Storage) int64 { @@ -1428,41 +1463,59 @@ func finishBootstrap(store kv.Storage) { } const quoteCommaQuote = "', '" -const loadCommonGlobalVarsSQL = "select HIGH_PRIORITY * from mysql.global_variables where variable_name in ('" + - variable.AutocommitVar + quoteCommaQuote + - variable.SQLModeVar + quoteCommaQuote + - variable.MaxAllowedPacket + quoteCommaQuote + - variable.TimeZone + quoteCommaQuote + - variable.BlockEncryptionMode + quoteCommaQuote + - variable.WaitTimeout + quoteCommaQuote + - variable.InteractiveTimeout + quoteCommaQuote + - variable.MaxPreparedStmtCount + quoteCommaQuote + + +var builtinGlobalVariable = []string{ + variable.AutocommitVar, + variable.SQLModeVar, + variable.MaxAllowedPacket, + variable.TimeZone, + variable.BlockEncryptionMode, + variable.WaitTimeout, + variable.InteractiveTimeout, + variable.MaxPreparedStmtCount, /* TiDB specific global variables: */ - variable.TiDBSkipUTF8Check + quoteCommaQuote + - variable.TiDBIndexJoinBatchSize + quoteCommaQuote + - variable.TiDBIndexLookupSize + quoteCommaQuote + - variable.TiDBIndexLookupConcurrency + quoteCommaQuote + - variable.TiDBIndexLookupJoinConcurrency + quoteCommaQuote + - variable.TiDBIndexSerialScanConcurrency + quoteCommaQuote + - variable.TiDBHashJoinConcurrency + quoteCommaQuote + - variable.TiDBProjectionConcurrency + quoteCommaQuote + - variable.TiDBHashAggPartialConcurrency + quoteCommaQuote + - variable.TiDBHashAggFinalConcurrency + quoteCommaQuote + - variable.TiDBBackoffLockFast + quoteCommaQuote + - variable.TiDBConstraintCheckInPlace + quoteCommaQuote + - variable.TiDBDDLReorgWorkerCount + quoteCommaQuote + - variable.TiDBDDLReorgBatchSize + quoteCommaQuote + - variable.TiDBOptInSubqToJoinAndAgg + quoteCommaQuote + - variable.TiDBDistSQLScanConcurrency + quoteCommaQuote + - variable.TiDBInitChunkSize + quoteCommaQuote + - variable.TiDBMaxChunkSize + quoteCommaQuote + - variable.TiDBEnableCascadesPlanner + quoteCommaQuote + - variable.TiDBRetryLimit + quoteCommaQuote + - variable.TiDBDisableTxnAutoRetry + quoteCommaQuote + - variable.TiDBEnableWindowFunction + "')" + variable.TiDBSkipUTF8Check, + variable.TiDBIndexJoinBatchSize, + variable.TiDBIndexLookupSize, + variable.TiDBIndexLookupConcurrency, + variable.TiDBIndexLookupJoinConcurrency, + variable.TiDBIndexSerialScanConcurrency, + variable.TiDBHashJoinConcurrency, + variable.TiDBProjectionConcurrency, + variable.TiDBHashAggPartialConcurrency, + variable.TiDBHashAggFinalConcurrency, + variable.TiDBBackoffLockFast, + variable.TiDBConstraintCheckInPlace, + variable.TiDBDDLReorgWorkerCount, + variable.TiDBDDLReorgBatchSize, + variable.TiDBOptInSubqToJoinAndAgg, + variable.TiDBDistSQLScanConcurrency, + variable.TiDBInitChunkSize, + variable.TiDBMaxChunkSize, + variable.TiDBEnableCascadesPlanner, + variable.TiDBRetryLimit, + variable.TiDBDisableTxnAutoRetry, + variable.TiDBEnableWindowFunction, +} + +var ( + loadCommonGlobalVarsSQLOnce sync.Once + loadCommonGlobalVarsSQL string +) + +func initLoadCommonGlobalVarsSQL() { + loadCommonGlobalVarsSQLOnce.Do(func() { + vars := append(make([]string, 0, len(builtinGlobalVariable)+len(variable.PluginVarNames)), builtinGlobalVariable...) + if len(variable.PluginVarNames) > 0 { + vars = append(vars, variable.PluginVarNames...) + } + loadCommonGlobalVarsSQL = "select HIGH_PRIORITY * from mysql.global_variables where variable_name in ('" + strings.Join(vars, quoteCommaQuote) + "')" + }) +} // loadCommonGlobalVariablesIfNeeded loads and applies commonly used global variables for the session. func (s *session) loadCommonGlobalVariablesIfNeeded() error { + initLoadCommonGlobalVarsSQL() vars := s.sessionVars if vars.CommonGlobalLoaded { return nil diff --git a/session/session_fail_test.go b/session/session_fail_test.go index 5459d23dc2d2f..9f160a8fda1ec 100644 --- a/session/session_fail_test.go +++ b/session/session_fail_test.go @@ -34,13 +34,48 @@ func (s *testSessionSuite) TestFailStatementCommit(c *C) { gofail.Disable("github.com/pingcap/tidb/session/mockStmtCommitError") - tk.MustQuery("select * from t").Check(testkit.Rows("1")) - tk.MustExec("insert into t values (3)") - tk.MustExec("insert into t values (4)") + _, err = tk.Exec("select * from t") + c.Assert(err, NotNil) + _, err = tk.Exec("insert into t values (3)") + c.Assert(err, NotNil) + _, err = tk.Exec("insert into t values (4)") + c.Assert(err, NotNil) _, err = tk.Exec("commit") c.Assert(err, NotNil) tk.MustQuery(`select * from t`).Check(testkit.Rows()) + + tk.MustExec("insert into t values (1)") + + tk.MustExec("begin") + tk.MustExec("insert into t values (2)") + tk.MustExec("commit") + + tk.MustExec("begin") + tk.MustExec("insert into t values (3)") + tk.MustExec("rollback") + + tk.MustQuery(`select * from t`).Check(testkit.Rows("1", "2")) +} + +func (s *testSessionSuite) TestFailStatementCommitInRetry(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create table t (id int)") + + tk.MustExec("begin") + tk.MustExec("insert into t values (1)") + tk.MustExec("insert into t values (2),(3),(4),(5)") + tk.MustExec("insert into t values (6)") + + gofail.Enable("github.com/pingcap/tidb/session/mockCommitError8942", `return(true)`) + gofail.Enable("github.com/pingcap/tidb/session/mockStmtCommitError", `return(true)`) + _, err := tk.Exec("commit") + c.Assert(err, NotNil) + gofail.Disable("github.com/pingcap/tidb/session/mockCommitError8942") + gofail.Disable("github.com/pingcap/tidb/session/mockStmtCommitError") + + tk.MustExec("insert into t values (6)") + tk.MustQuery(`select * from t`).Check(testkit.Rows("6")) } func (s *testSessionSuite) TestGetTSFailDirtyState(c *C) { diff --git a/session/session_test.go b/session/session_test.go index df77ed3c401fc..218452da99361 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -1128,10 +1128,10 @@ func (s *testSessionSuite) TestResultType(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) rs, err := tk.Exec(`select cast(null as char(30))`) c.Assert(err, IsNil) - chk := rs.NewChunk() - err = rs.Next(context.Background(), chk) + req := rs.NewRecordBatch() + err = rs.Next(context.Background(), req) c.Assert(err, IsNil) - c.Assert(chk.GetRow(0).IsNull(0), IsTrue) + c.Assert(req.GetRow(0).IsNull(0), IsTrue) c.Assert(rs.Fields()[0].Column.FieldType.Tp, Equals, mysql.TypeVarString) } @@ -1877,18 +1877,18 @@ func (s *testSchemaSuite) TestTableReaderChunk(c *C) { }() rs, err := tk.Exec("select * from chk") c.Assert(err, IsNil) - chk := rs.NewChunk() + req := rs.NewRecordBatch() var count int var numChunks int for { - err = rs.Next(context.TODO(), chk) + err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) - numRows := chk.NumRows() + numRows := req.NumRows() if numRows == 0 { break } for i := 0; i < numRows; i++ { - c.Assert(chk.GetRow(i).GetInt64(0), Equals, int64(count)) + c.Assert(req.GetRow(i).GetInt64(0), Equals, int64(count)) count++ } numChunks++ @@ -1914,15 +1914,15 @@ func (s *testSchemaSuite) TestInsertExecChunk(c *C) { c.Assert(err, IsNil) var idx int for { - chk := rs.NewChunk() - err = rs.Next(context.TODO(), chk) + req := rs.NewRecordBatch() + err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } - for rowIdx := 0; rowIdx < chk.NumRows(); rowIdx++ { - row := chk.GetRow(rowIdx) + for rowIdx := 0; rowIdx < req.NumRows(); rowIdx++ { + row := req.GetRow(rowIdx) c.Assert(row.GetInt64(0), Equals, int64(idx)) idx++ } @@ -1948,15 +1948,15 @@ func (s *testSchemaSuite) TestUpdateExecChunk(c *C) { c.Assert(err, IsNil) var idx int for { - chk := rs.NewChunk() - err = rs.Next(context.TODO(), chk) + req := rs.NewRecordBatch() + err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } - for rowIdx := 0; rowIdx < chk.NumRows(); rowIdx++ { - row := chk.GetRow(rowIdx) + for rowIdx := 0; rowIdx < req.NumRows(); rowIdx++ { + row := req.GetRow(rowIdx) c.Assert(row.GetInt64(0), Equals, int64(idx+100)) idx++ } @@ -1983,12 +1983,12 @@ func (s *testSchemaSuite) TestDeleteExecChunk(c *C) { rs, err := tk.Exec("select * from chk") c.Assert(err, IsNil) - chk := rs.NewChunk() - err = rs.Next(context.TODO(), chk) + req := rs.NewRecordBatch() + err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) - c.Assert(chk.NumRows(), Equals, 1) + c.Assert(req.NumRows(), Equals, 1) - row := chk.GetRow(0) + row := req.GetRow(0) c.Assert(row.GetInt64(0), Equals, int64(99)) rs.Close() } @@ -2015,16 +2015,16 @@ func (s *testSchemaSuite) TestDeleteMultiTableExecChunk(c *C) { var idx int for { - chk := rs.NewChunk() - err = rs.Next(context.TODO(), chk) + req := rs.NewRecordBatch() + err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } - for i := 0; i < chk.NumRows(); i++ { - row := chk.GetRow(i) + for i := 0; i < req.NumRows(); i++ { + row := req.GetRow(i) c.Assert(row.GetInt64(0), Equals, int64(idx+50)) idx++ } @@ -2035,10 +2035,10 @@ func (s *testSchemaSuite) TestDeleteMultiTableExecChunk(c *C) { rs, err = tk.Exec("select * from chk2") c.Assert(err, IsNil) - chk := rs.NewChunk() - err = rs.Next(context.TODO(), chk) + req := rs.NewRecordBatch() + err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) - c.Assert(chk.NumRows(), Equals, 0) + c.Assert(req.NumRows(), Equals, 0) rs.Close() } @@ -2058,18 +2058,18 @@ func (s *testSchemaSuite) TestIndexLookUpReaderChunk(c *C) { tk.Se.GetSessionVars().IndexLookupSize = 10 rs, err := tk.Exec("select * from chk order by k") c.Assert(err, IsNil) - chk := rs.NewChunk() + req := rs.NewRecordBatch() var count int for { - err = rs.Next(context.TODO(), chk) + err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) - numRows := chk.NumRows() + numRows := req.NumRows() if numRows == 0 { break } for i := 0; i < numRows; i++ { - c.Assert(chk.GetRow(i).GetInt64(0), Equals, int64(count)) - c.Assert(chk.GetRow(i).GetInt64(1), Equals, int64(count)) + c.Assert(req.GetRow(i).GetInt64(0), Equals, int64(count)) + c.Assert(req.GetRow(i).GetInt64(1), Equals, int64(count)) count++ } } @@ -2078,17 +2078,17 @@ func (s *testSchemaSuite) TestIndexLookUpReaderChunk(c *C) { rs, err = tk.Exec("select k from chk where c < 90 order by k") c.Assert(err, IsNil) - chk = rs.NewChunk() + req = rs.NewRecordBatch() count = 0 for { - err = rs.Next(context.TODO(), chk) + err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) - numRows := chk.NumRows() + numRows := req.NumRows() if numRows == 0 { break } for i := 0; i < numRows; i++ { - c.Assert(chk.GetRow(i).GetInt64(0), Equals, int64(count)) + c.Assert(req.GetRow(i).GetInt64(0), Equals, int64(count)) count++ } } @@ -2457,6 +2457,17 @@ func (s *testSessionSuite) TestUpdatePrivilege(c *C) { // In fact, the privlege check for t1 should be update, and for t2 should be select. _, err = tk1.Exec("update t1,t2 set t1.id = t2.id;") c.Assert(err, IsNil) + + // Fix issue 8911 + tk.MustExec("create database weperk") + tk.MustExec("use weperk") + tk.MustExec("create table tb_wehub_server (id int, active_count int, used_count int)") + tk.MustExec("create user 'weperk'") + tk.MustExec("grant all privileges on weperk.* to 'weperk'@'%'") + c.Assert(tk1.Se.Auth(&auth.UserIdentity{Username: "weperk", Hostname: "%"}, + []byte(""), []byte("")), IsTrue) + tk1.MustExec("use weperk") + tk1.MustExec("update tb_wehub_server a set a.active_count=a.active_count+1,a.used_count=a.used_count+1 where id=1") } func (s *testSessionSuite) TestTxnGoString(c *C) { diff --git a/session/tidb.go b/session/tidb.go index 3df30dc955367..99160fc3bb867 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -192,6 +192,10 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) var err error var rs sqlexec.RecordSet se := sctx.(*session) + err = se.checkTxnAborted(s) + if err != nil { + return nil, err + } rs, err = s.Exec(ctx) sessVars := se.GetSessionVars() // All the history should be added here. @@ -245,23 +249,23 @@ func GetRows4Test(ctx context.Context, sctx sessionctx.Context, rs sqlexec.Recor return nil, nil } var rows []chunk.Row - chk := rs.NewChunk() + req := rs.NewRecordBatch() for { // Since we collect all the rows, we can not reuse the chunk. - iter := chunk.NewIterator4Chunk(chk) + iter := chunk.NewIterator4Chunk(req.Chunk) - err := rs.Next(ctx, chk) + err := rs.Next(ctx, req) if err != nil { return nil, errors.Trace(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } for row := iter.Begin(); row != iter.End(); row = iter.Next() { rows = append(rows, row) } - chk = chunk.Renew(chk, sctx.GetSessionVars().MaxChunkSize) + req.Chunk = chunk.Renew(req.Chunk, sctx.GetSessionVars().MaxChunkSize) } return rows, nil } diff --git a/session/txn.go b/session/txn.go index 6e4f7ad55c9d5..a4f28f94732d6 100644 --- a/session/txn.go +++ b/session/txn.go @@ -164,6 +164,13 @@ func (st *TxnState) Commit(ctx context.Context) error { } return errors.Trace(st.doNotCommit) } + + // mockCommitError8942 is used for PR #8942. + // gofail: var mockCommitError8942 bool + // if mockCommitError8942 { + // return kv.ErrRetryable + // } + return errors.Trace(st.Transaction.Commit(ctx)) } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 1830f096e46dd..87d856d089c41 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -47,8 +47,10 @@ type StatementContext struct { // If IsDDLJobInQueue is true, it means the DDL job is in the queue of storage, and it can be handled by the DDL worker. IsDDLJobInQueue bool InInsertStmt bool - InUpdateOrDeleteStmt bool + InUpdateStmt bool + InDeleteStmt bool InSelectStmt bool + InLoadDataStmt bool IgnoreTruncate bool IgnoreZeroInDate bool DupKeyAsWarning bool @@ -379,3 +381,21 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails { sc.mu.Unlock() return details } + +// ShouldClipToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types. +// This is the case for `insert`, `update`, `alter table` and `load data infile` statements, when not in strict SQL mode. +// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html +func (sc *StatementContext) ShouldClipToZero() bool { + // TODO: Currently altering column of integer to unsigned integer is not supported. + // If it is supported one day, that case should be added here. + return sc.InInsertStmt || sc.InLoadDataStmt +} + +// ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows, +// so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types. +func (sc *StatementContext) ShouldIgnoreOverflowError() bool { + if (sc.InInsertStmt && sc.TruncateAsWarning) || sc.InLoadDataStmt { + return true + } + return false +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 97fba074939ff..66ba563f0d422 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -681,10 +681,6 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { s.OptimizerSelectivityLevel = tidbOptPositiveInt32(val, DefTiDBOptimizerSelectivityLevel) case TiDBEnableTablePartition: s.EnableTablePartition = val - case TiDBDDLReorgWorkerCount: - SetDDLReorgWorkerCounter(int32(tidbOptPositiveInt32(val, DefTiDBDDLReorgWorkerCount))) - case TiDBDDLReorgBatchSize: - SetDDLReorgBatchSize(int32(tidbOptPositiveInt32(val, DefTiDBDDLReorgBatchSize))) case TiDBDDLReorgPriority: s.setDDLReorgPriority(val) case TiDBForcePriority: @@ -698,6 +694,16 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { return nil } +// SetLocalSystemVar sets values of the local variables which in "server" scope. +func SetLocalSystemVar(name string, val string) { + switch name { + case TiDBDDLReorgWorkerCount: + SetDDLReorgWorkerCounter(int32(tidbOptPositiveInt32(val, DefTiDBDDLReorgWorkerCount))) + case TiDBDDLReorgBatchSize: + SetDDLReorgBatchSize(int32(tidbOptPositiveInt32(val, DefTiDBDDLReorgBatchSize))) + } +} + // special session variables. const ( SQLModeVar = "sql_mode" diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 5ec78717fe7f5..9befa39267a1c 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -57,6 +57,9 @@ func GetSysVar(name string) *SysVar { return SysVars[name] } +// PluginVarNames is global plugin var names set. +var PluginVarNames []string + // Variable error codes. const ( CodeUnknownStatusVar terror.ErrCode = 1 @@ -315,7 +318,8 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal | ScopeSession, "sort_buffer_size", "262144"}, {ScopeGlobal, "innodb_flush_neighbors", "1"}, {ScopeNone, "innodb_use_sys_malloc", "ON"}, - {ScopeNone, "plugin_dir", "/usr/local/mysql/lib/plugin/"}, + {ScopeSession, PluginLoad, ""}, + {ScopeSession, PluginDir, "/data/deploy/plugin"}, {ScopeNone, "performance_schema_max_socket_classes", "10"}, {ScopeNone, "performance_schema_max_stage_classes", "150"}, {ScopeGlobal, "innodb_purge_batch_size", "300"}, @@ -672,8 +676,8 @@ var defaultSysVars = []*SysVar{ {ScopeSession, TiDBSlowLogThreshold, strconv.Itoa(logutil.DefaultSlowThreshold)}, {ScopeSession, TiDBQueryLogMaxLen, strconv.Itoa(logutil.DefaultQueryLogMaxLen)}, {ScopeSession, TiDBConfig, ""}, - {ScopeGlobal | ScopeSession, TiDBDDLReorgWorkerCount, strconv.Itoa(DefTiDBDDLReorgWorkerCount)}, - {ScopeGlobal | ScopeSession, TiDBDDLReorgBatchSize, strconv.Itoa(DefTiDBDDLReorgBatchSize)}, + {ScopeGlobal, TiDBDDLReorgWorkerCount, strconv.Itoa(DefTiDBDDLReorgWorkerCount)}, + {ScopeGlobal, TiDBDDLReorgBatchSize, strconv.Itoa(DefTiDBDDLReorgBatchSize)}, {ScopeSession, TiDBDDLReorgPriority, "PRIORITY_LOW"}, {ScopeSession, TiDBForcePriority, mysql.Priority2Str[DefTiDBForcePriority]}, {ScopeSession, TiDBEnableRadixJoin, boolToIntStr(DefTiDBUseRadixJoin)}, @@ -789,6 +793,10 @@ const ( ValidatePasswordNumberCount = "validate_password_number_count" // ValidatePasswordLength is the name of 'validate_password_length' system variable. ValidatePasswordLength = "validate_password_length" + // PluginDir is the name of 'plugin_dir' system variable. + PluginDir = "plugin_dir" + // PluginLoad is the name of 'plugin_load' system variable. + PluginLoad = "plugin_load" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables. diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 9c20d2e4d261a..89d498c5e1128 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -106,6 +106,10 @@ func GetSessionOnlySysVars(s *SessionVars, key string) (string, bool, error) { return strconv.FormatUint(atomic.LoadUint64(&config.GetGlobalConfig().Log.SlowThreshold), 10), true, nil case TiDBQueryLogMaxLen: return strconv.FormatUint(atomic.LoadUint64(&config.GetGlobalConfig().Log.QueryLogMaxLen), 10), true, nil + case PluginDir: + return config.GetGlobalConfig().Plugin.Dir, true, nil + case PluginLoad: + return config.GetGlobalConfig().Plugin.Load, true, nil } sVal, ok := s.systems[key] if ok { diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index fa7cc100b0956..b54dd8d392b96 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -216,10 +216,6 @@ func (s *testVarsutilSuite) TestVarsutil(c *C) { SetSessionSystemVar(v, TiDBOptimizerSelectivityLevel, types.NewIntDatum(1)) c.Assert(v.OptimizerSelectivityLevel, Equals, 1) - c.Assert(GetDDLReorgWorkerCounter(), Equals, int32(DefTiDBDDLReorgWorkerCount)) - SetSessionSystemVar(v, TiDBDDLReorgWorkerCount, types.NewIntDatum(1)) - c.Assert(GetDDLReorgWorkerCounter(), Equals, int32(1)) - err = SetSessionSystemVar(v, TiDBDDLReorgWorkerCount, types.NewIntDatum(-1)) c.Assert(terror.ErrorEqual(err, ErrWrongValueForVar), IsTrue) diff --git a/statistics/bootstrap.go b/statistics/bootstrap.go index 0cda1a965fbc0..2ef66f577950b 100644 --- a/statistics/bootstrap.go +++ b/statistics/bootstrap.go @@ -67,14 +67,14 @@ func (h *Handle) initStatsMeta(is infoschema.InfoSchema) (statsCache, error) { return nil, errors.Trace(err) } tables := statsCache{} - chk := rc[0].NewChunk() - iter := chunk.NewIterator4Chunk(chk) + req := rc[0].NewRecordBatch() + iter := chunk.NewIterator4Chunk(req.Chunk) for { - err := rc[0].Next(context.TODO(), chk) + err := rc[0].Next(context.TODO(), req) if err != nil { return nil, errors.Trace(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } h.initStatsMeta4Chunk(is, tables, iter) @@ -136,14 +136,14 @@ func (h *Handle) initStatsHistograms(is infoschema.InfoSchema, tables statsCache if err != nil { return errors.Trace(err) } - chk := rc[0].NewChunk() - iter := chunk.NewIterator4Chunk(chk) + req := rc[0].NewRecordBatch() + iter := chunk.NewIterator4Chunk(req.Chunk) for { - err := rc[0].Next(context.TODO(), chk) + err := rc[0].Next(context.TODO(), req) if err != nil { return errors.Trace(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } h.initStatsHistograms4Chunk(is, tables, iter) @@ -208,14 +208,14 @@ func (h *Handle) initStatsBuckets(tables statsCache) error { if err != nil { return errors.Trace(err) } - chk := rc[0].NewChunk() - iter := chunk.NewIterator4Chunk(chk) + req := rc[0].NewRecordBatch() + iter := chunk.NewIterator4Chunk(req.Chunk) for { - err := rc[0].Next(context.TODO(), chk) + err := rc[0].Next(context.TODO(), req) if err != nil { return errors.Trace(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } initStatsBuckets4Chunk(h.mu.ctx, tables, iter) diff --git a/statistics/ddl.go b/statistics/ddl.go index ec4e344217806..fa390ce664605 100644 --- a/statistics/ddl.go +++ b/statistics/ddl.go @@ -137,12 +137,12 @@ func (h *Handle) insertColStats2KV(physicalID int64, colInfo *model.ColumnInfo) if err != nil { return } - chk := rs[0].NewChunk() - err = rs[0].Next(ctx, chk) + req := rs[0].NewRecordBatch() + err = rs[0].Next(ctx, req) if err != nil { return } - count := chk.GetRow(0).GetInt64(0) + count := req.GetRow(0).GetInt64(0) value := types.NewDatum(colInfo.OriginDefaultValue) value, err = value.ConvertTo(h.mu.ctx.GetSessionVars().StmtCtx, &colInfo.FieldType) if err != nil { diff --git a/statistics/feedback.go b/statistics/feedback.go index 24fadd8fa762f..392efd3c9b9a0 100644 --- a/statistics/feedback.go +++ b/statistics/feedback.go @@ -294,8 +294,7 @@ func buildBucketFeedback(h *Histogram, feedback *QueryFeedback) (map[int]*Bucket } total := 0 sc := &stmtctx.StatementContext{TimeZone: time.UTC} - kind := feedback.feedback[0].lower.Kind() - min, max := getMinValue(kind, h.tp), getMaxValue(kind, h.tp) + min, max := getMinValue(h.tp), getMaxValue(h.tp) for _, fb := range feedback.feedback { skip, err := fb.adjustFeedbackBoundaries(sc, &min, &max) if err != nil { @@ -723,11 +722,18 @@ func decodeFeedbackForIndex(q *QueryFeedback, pb *queryFeedback, c *CMSketch) { } } -func decodeFeedbackForPK(q *QueryFeedback, pb *queryFeedback) { +func decodeFeedbackForPK(q *QueryFeedback, pb *queryFeedback, isUnsigned bool) { q.tp = pkType // decode feedback for primary key for i := 0; i < len(pb.IntRanges); i += 2 { - lower, upper := types.NewIntDatum(pb.IntRanges[i]), types.NewIntDatum(pb.IntRanges[i+1]) + var lower, upper types.Datum + if isUnsigned { + lower.SetUint64(uint64(pb.IntRanges[i])) + upper.SetUint64(uint64(pb.IntRanges[i+1])) + } else { + lower.SetInt64(pb.IntRanges[i]) + upper.SetInt64(pb.IntRanges[i+1]) + } q.feedback = append(q.feedback, feedback{&lower, &upper, pb.Counts[i/2], 0}) } } @@ -748,7 +754,7 @@ func decodeFeedbackForColumn(q *QueryFeedback, pb *queryFeedback) error { return nil } -func decodeFeedback(val []byte, q *QueryFeedback, c *CMSketch) error { +func decodeFeedback(val []byte, q *QueryFeedback, c *CMSketch, isUnsigned bool) error { buf := bytes.NewBuffer(val) dec := gob.NewDecoder(buf) pb := &queryFeedback{} @@ -759,7 +765,7 @@ func decodeFeedback(val []byte, q *QueryFeedback, c *CMSketch) error { if len(pb.IndexRanges) > 0 || len(pb.HashValues) > 0 { decodeFeedbackForIndex(q, pb, c) } else if len(pb.IntRanges) > 0 { - decodeFeedbackForPK(q, pb) + decodeFeedbackForPK(q, pb, isUnsigned) } else { err := decodeFeedbackForColumn(q, pb) if err != nil { @@ -1073,15 +1079,14 @@ func (q *QueryFeedback) dumpRangeFeedback(h *Handle, ran *ranger.Range, rangeCou ran.LowVal[0].SetBytes(lower) ran.HighVal[0].SetBytes(upper) } else { - k := q.hist.GetLower(0).Kind() - if !supportColumnType(k) { + if !supportColumnType(q.hist.tp) { return nil } if ran.LowVal[0].Kind() == types.KindMinNotNull { - ran.LowVal[0] = getMinValue(k, q.hist.tp) + ran.LowVal[0] = getMinValue(q.hist.tp) } if ran.HighVal[0].Kind() == types.KindMaxValue { - ran.HighVal[0] = getMaxValue(k, q.hist.tp) + ran.HighVal[0] = getMaxValue(q.hist.tp) } } ranges := q.hist.SplitRange([]*ranger.Range{ran}) @@ -1128,27 +1133,30 @@ func setNextValue(d *types.Datum) { } // supportColumnType checks if the type of the column can be updated by feedback. -func supportColumnType(k byte) bool { - switch k { - case types.KindInt64, types.KindUint64, types.KindFloat32, types.KindFloat64, types.KindString, types.KindBytes, - types.KindMysqlDecimal, types.KindMysqlDuration, types.KindMysqlTime: +func supportColumnType(ft *types.FieldType) bool { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeFloat, + mysql.TypeDouble, mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, + mysql.TypeNewDecimal, mysql.TypeDuration, mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: return true default: return false } } -func getMaxValue(k byte, ft *types.FieldType) (max types.Datum) { - switch k { - case types.KindInt64: - max.SetInt64(types.SignedUpperBound[ft.Tp]) - case types.KindUint64: - max.SetUint64(types.UnsignedUpperBound[ft.Tp]) - case types.KindFloat32: +func getMaxValue(ft *types.FieldType) (max types.Datum) { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + if mysql.HasUnsignedFlag(ft.Flag) { + max.SetUint64(types.UnsignedUpperBound[ft.Tp]) + } else { + max.SetInt64(types.SignedUpperBound[ft.Tp]) + } + case mysql.TypeFloat: max.SetFloat32(float32(types.GetMaxFloat(ft.Flen, ft.Decimal))) - case types.KindFloat64: + case mysql.TypeDouble: max.SetFloat64(types.GetMaxFloat(ft.Flen, ft.Decimal)) - case types.KindString, types.KindBytes: + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: val := types.MaxValueDatum() bytes, err := codec.EncodeKey(nil, nil, val) // should not happen @@ -1156,11 +1164,11 @@ func getMaxValue(k byte, ft *types.FieldType) (max types.Datum) { log.Error(err) } max.SetBytes(bytes) - case types.KindMysqlDecimal: + case mysql.TypeNewDecimal: max.SetMysqlDecimal(types.NewMaxOrMinDec(false, ft.Flen, ft.Decimal)) - case types.KindMysqlDuration: + case mysql.TypeDuration: max.SetMysqlDuration(types.Duration{Duration: math.MaxInt64}) - case types.KindMysqlTime: + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { max.SetMysqlTime(types.Time{Time: types.MaxDatetime, Type: ft.Tp}) } else { @@ -1170,17 +1178,19 @@ func getMaxValue(k byte, ft *types.FieldType) (max types.Datum) { return } -func getMinValue(k byte, ft *types.FieldType) (min types.Datum) { - switch k { - case types.KindInt64: - min.SetInt64(types.SignedLowerBound[ft.Tp]) - case types.KindUint64: - min.SetUint64(0) - case types.KindFloat32: +func getMinValue(ft *types.FieldType) (min types.Datum) { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + if mysql.HasUnsignedFlag(ft.Flag) { + min.SetUint64(0) + } else { + min.SetInt64(types.SignedLowerBound[ft.Tp]) + } + case mysql.TypeFloat: min.SetFloat32(float32(-types.GetMaxFloat(ft.Flen, ft.Decimal))) - case types.KindFloat64: + case mysql.TypeDouble: min.SetFloat64(-types.GetMaxFloat(ft.Flen, ft.Decimal)) - case types.KindString, types.KindBytes: + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: val := types.MinNotNullDatum() bytes, err := codec.EncodeKey(nil, nil, val) // should not happen @@ -1188,11 +1198,11 @@ func getMinValue(k byte, ft *types.FieldType) (min types.Datum) { log.Error(err) } min.SetBytes(bytes) - case types.KindMysqlDecimal: + case mysql.TypeNewDecimal: min.SetMysqlDecimal(types.NewMaxOrMinDec(true, ft.Flen, ft.Decimal)) - case types.KindMysqlDuration: + case mysql.TypeDuration: min.SetMysqlDuration(types.Duration{Duration: math.MinInt64}) - case types.KindMysqlTime: + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { min.SetMysqlTime(types.Time{Time: types.MinDatetime, Type: ft.Tp}) } else { diff --git a/statistics/feedback_test.go b/statistics/feedback_test.go index 08058c386bef2..da19fc233342d 100644 --- a/statistics/feedback_test.go +++ b/statistics/feedback_test.go @@ -221,7 +221,7 @@ func (s *testFeedbackSuite) TestFeedbackEncoding(c *C) { val, err := encodeFeedback(q) c.Assert(err, IsNil) rq := &QueryFeedback{} - c.Assert(decodeFeedback(val, rq, nil), IsNil) + c.Assert(decodeFeedback(val, rq, nil, false), IsNil) for _, fb := range rq.feedback { fb.lower.SetBytes(codec.EncodeInt(nil, fb.lower.GetInt64())) fb.upper.SetBytes(codec.EncodeInt(nil, fb.upper.GetInt64())) @@ -236,7 +236,7 @@ func (s *testFeedbackSuite) TestFeedbackEncoding(c *C) { c.Assert(err, IsNil) rq = &QueryFeedback{} cms := NewCMSketch(4, 4) - c.Assert(decodeFeedback(val, rq, cms), IsNil) + c.Assert(decodeFeedback(val, rq, cms, false), IsNil) c.Assert(cms.QueryBytes(codec.EncodeInt(nil, 0)), Equals, uint32(1)) q.feedback = q.feedback[:1] c.Assert(q.Equal(rq), IsTrue) diff --git a/statistics/sample.go b/statistics/sample.go index ca63d6d572b2c..f454678478187 100644 --- a/statistics/sample.go +++ b/statistics/sample.go @@ -162,14 +162,14 @@ func (s SampleBuilder) CollectColumnStats() ([]*SampleCollector, *SortedBuilder, } } ctx := context.TODO() - chk := s.RecordSet.NewChunk() - it := chunk.NewIterator4Chunk(chk) + req := s.RecordSet.NewRecordBatch() + it := chunk.NewIterator4Chunk(req.Chunk) for { - err := s.RecordSet.Next(ctx, chk) + err := s.RecordSet.Next(ctx, req) if err != nil { return nil, nil, errors.Trace(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { return collectors, s.PkBuilder, nil } if len(s.RecordSet.Fields()) == 0 { diff --git a/statistics/statistics_test.go b/statistics/statistics_test.go index e25efa2a659a7..cd9761457f14c 100644 --- a/statistics/statistics_test.go +++ b/statistics/statistics_test.go @@ -83,23 +83,23 @@ func (r *recordSet) getNext() []types.Datum { return row } -func (r *recordSet) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (r *recordSet) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() row := r.getNext() if row != nil { for i := 0; i < len(row); i++ { - chk.AppendDatum(i, &row[i]) + req.AppendDatum(i, &row[i]) } } return nil } -func (r *recordSet) NewChunk() *chunk.Chunk { +func (r *recordSet) NewRecordBatch() *chunk.RecordBatch { fields := make([]*types.FieldType, 0, len(r.fields)) for _, field := range r.fields { fields = append(fields, &field.Column.FieldType) } - return chunk.NewChunkWithCapacity(fields, 32) + return chunk.NewRecordBatch(chunk.NewChunkWithCapacity(fields, 32)) } func (r *recordSet) Close() error { @@ -174,15 +174,15 @@ func buildPK(sctx sessionctx.Context, numBuckets, id int64, records sqlexec.Reco b := NewSortedBuilder(sctx.GetSessionVars().StmtCtx, numBuckets, id, types.NewFieldType(mysql.TypeLonglong)) ctx := context.Background() for { - chk := records.NewChunk() - err := records.Next(ctx, chk) + req := records.NewRecordBatch() + err := records.Next(ctx, req) if err != nil { return 0, nil, errors.Trace(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } - it := chunk.NewIterator4Chunk(chk) + it := chunk.NewIterator4Chunk(req.Chunk) for row := it.Begin(); row != it.End(); row = it.Next() { datums := RowToDatums(row, records.Fields()) err = b.Iterate(datums[0]) @@ -198,14 +198,14 @@ func buildIndex(sctx sessionctx.Context, numBuckets, id int64, records sqlexec.R b := NewSortedBuilder(sctx.GetSessionVars().StmtCtx, numBuckets, id, types.NewFieldType(mysql.TypeBlob)) cms := NewCMSketch(8, 2048) ctx := context.Background() - chk := records.NewChunk() - it := chunk.NewIterator4Chunk(chk) + req := records.NewRecordBatch() + it := chunk.NewIterator4Chunk(req.Chunk) for { - err := records.Next(ctx, chk) + err := records.Next(ctx, req) if err != nil { return 0, nil, nil, errors.Trace(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } for row := it.Begin(); row != it.End(); row = it.Next() { diff --git a/statistics/update.go b/statistics/update.go index 143a88a546d77..b65c4cdbb7734 100644 --- a/statistics/update.go +++ b/statistics/update.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx/variable" @@ -561,7 +562,7 @@ func (h *Handle) handleSingleHistogramUpdate(is infoschema.InfoSchema, rows []ch } q := &QueryFeedback{} for _, row := range rows { - err1 := decodeFeedback(row.GetBytes(3), q, cms) + err1 := decodeFeedback(row.GetBytes(3), q, cms, mysql.HasUnsignedFlag(hist.tp.Flag)) if err1 != nil { log.Debugf("decode feedback failed, err: %v", errors.ErrorStack(err)) } diff --git a/statistics/update_test.go b/statistics/update_test.go index 13712ad9f3e09..4e4d18a44faeb 100644 --- a/statistics/update_test.go +++ b/statistics/update_test.go @@ -1362,3 +1362,61 @@ func (s *testStatsSuite) TestFeedbackRanges(c *C) { c.Assert(tbl.Columns[t.colID].ToString(0), Equals, tests[i].hist) } } + +func (s *testStatsSuite) TestUnsignedFeedbackRanges(c *C) { + defer cleanEnv(c, s.store, s.do) + testKit := testkit.NewTestKit(c, s.store) + h := s.do.StatsHandle() + oriProbability := statistics.FeedbackProbability + oriNumber := statistics.MaxNumberOfRanges + defer func() { + statistics.FeedbackProbability = oriProbability + statistics.MaxNumberOfRanges = oriNumber + }() + statistics.FeedbackProbability = 1 + + testKit.MustExec("use test") + testKit.MustExec("create table t (a tinyint unsigned, primary key(a))") + for i := 0; i < 20; i++ { + testKit.MustExec(fmt.Sprintf("insert into t values (%d)", i)) + } + h.HandleDDLEvent(<-h.DDLEventCh()) + c.Assert(h.DumpStatsDeltaToKV(statistics.DumpAll), IsNil) + testKit.MustExec("analyze table t with 3 buckets") + for i := 30; i < 40; i++ { + testKit.MustExec(fmt.Sprintf("insert into t values (%d)", i)) + } + c.Assert(h.DumpStatsDeltaToKV(statistics.DumpAll), IsNil) + tests := []struct { + sql string + hist string + }{ + { + sql: "select * from t where a <= 50", + hist: "column:1 ndv:30 totColSize:0\n" + + "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 14 lower_bound: 16 upper_bound: 50 repeats: 0", + }, + { + sql: "select count(*) from t", + hist: "column:1 ndv:30 totColSize:0\n" + + "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 14 lower_bound: 16 upper_bound: 255 repeats: 0", + }, + } + is := s.do.InfoSchema() + table, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + for i, t := range tests { + testKit.MustQuery(t.sql) + c.Assert(h.DumpStatsDeltaToKV(statistics.DumpAll), IsNil) + c.Assert(h.DumpStatsFeedbackToKV(), IsNil) + c.Assert(h.HandleUpdateStats(s.do.InfoSchema()), IsNil) + c.Assert(err, IsNil) + h.Update(is) + tblInfo := table.Meta() + tbl := h.GetTableStats(tblInfo) + c.Assert(tbl.Columns[1].ToString(0), Equals, tests[i].hist) + } +} diff --git a/store/mockstore/mocktikv/analyze.go b/store/mockstore/mocktikv/analyze.go index ad0b1af5bdd5f..afde4346d9fa8 100644 --- a/store/mockstore/mocktikv/analyze.go +++ b/store/mockstore/mocktikv/analyze.go @@ -212,24 +212,24 @@ func (e *analyzeColumnsExec) getNext(ctx context.Context) ([]types.Datum, error) return datumRow, nil } -func (e *analyzeColumnsExec) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() +func (e *analyzeColumnsExec) Next(ctx context.Context, req *chunk.RecordBatch) error { + req.Reset() row, err := e.getNext(ctx) if row == nil || err != nil { return errors.Trace(err) } for i := 0; i < len(row); i++ { - chk.AppendDatum(i, &row[i]) + req.AppendDatum(i, &row[i]) } return nil } -func (e *analyzeColumnsExec) NewChunk() *chunk.Chunk { +func (e *analyzeColumnsExec) NewRecordBatch() *chunk.RecordBatch { fields := make([]*types.FieldType, 0, len(e.fields)) for _, field := range e.fields { fields = append(fields, &field.Column.FieldType) } - return chunk.NewChunkWithCapacity(fields, 1) + return chunk.NewRecordBatch(chunk.NewChunkWithCapacity(fields, 1)) } // Close implements the sqlexec.RecordSet Close interface. diff --git a/store/tikv/backoff.go b/store/tikv/backoff.go index a24dc39aabe33..ca99dfa7785e9 100644 --- a/store/tikv/backoff.go +++ b/store/tikv/backoff.go @@ -87,7 +87,7 @@ const ( boTiKVRPC backoffType = iota BoTxnLock boTxnLockFast - boPDRPC + BoPDRPC BoRegionMiss BoUpdateLeader boServerBusy @@ -104,7 +104,7 @@ func (t backoffType) createFn(vars *kv.Variables) func(context.Context) int { return NewBackoffFn(200, 3000, EqualJitter) case boTxnLockFast: return NewBackoffFn(vars.BackoffLockFast, 3000, EqualJitter) - case boPDRPC: + case BoPDRPC: return NewBackoffFn(500, 3000, EqualJitter) case BoRegionMiss: // change base time to 2ms, because it may recover soon. @@ -125,7 +125,7 @@ func (t backoffType) String() string { return "txnLock" case boTxnLockFast: return "txnLockFast" - case boPDRPC: + case BoPDRPC: return "pdRPC" case BoRegionMiss: return "regionMiss" @@ -143,7 +143,7 @@ func (t backoffType) TError() error { return ErrTiKVServerTimeout case BoTxnLock, boTxnLockFast: return ErrResolveLockTimeout - case boPDRPC: + case BoPDRPC: return ErrPDServerTimeout.GenWithStackByArgs(txnRetryableMark) case BoRegionMiss, BoUpdateLeader: return ErrRegionUnavailable diff --git a/store/tikv/client.go b/store/tikv/client.go index 116ee96a9127d..af8e5cfd5d14f 100644 --- a/store/tikv/client.go +++ b/store/tikv/client.go @@ -32,10 +32,13 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/store/tikv/tikvrpc" + tidbutil "github.com/pingcap/tidb/util" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + gcodes "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + gstatus "google.golang.org/grpc/status" ) // MaxConnectionCount is the max gRPC connections that will be established with @@ -86,13 +89,116 @@ type connArray struct { v []*grpc.ClientConn // Bind with a background goroutine to process coprocessor streaming timeout. streamTimeout chan *tikvrpc.Lease + + // For batch commands. + batchCommandsCh chan *batchCommandsEntry + batchCommandsClients []*batchCommandsClient + tikvTransportLayerLoad uint64 +} + +type batchCommandsClient struct { + conn *grpc.ClientConn + client tikvpb.Tikv_BatchCommandsClient + batched sync.Map + idAlloc uint64 + tikvTransportLayerLoad *uint64 + + // Indicates the batch client is closed explicitly or not. + closed int32 + // Protect client when re-create the streaming. + clientLock sync.Mutex +} + +func (c *batchCommandsClient) isStopped() bool { + return atomic.LoadInt32(&c.closed) != 0 +} + +func (c *batchCommandsClient) failPendingRequests(err error) { + c.batched.Range(func(key, value interface{}) bool { + id, _ := key.(uint64) + entry, _ := value.(*batchCommandsEntry) + entry.err = err + close(entry.res) + c.batched.Delete(id) + return true + }) +} + +func (c *batchCommandsClient) batchRecvLoop(cfg config.TiKVClient) { + defer func() { + if r := recover(); r != nil { + buf := tidbutil.GetStack() + metrics.PanicCounter.WithLabelValues(metrics.LabelBatchRecvLoop).Inc() + log.Errorf("batchRecvLoop %v %s", r, buf) + log.Infof("Restart batchRecvLoop") + go c.batchRecvLoop(cfg) + } + }() + + for { + // When `conn.Close()` is called, `client.Recv()` will return an error. + resp, err := c.client.Recv() + if err != nil { + if c.isStopped() { + return + } + log.Errorf("batchRecvLoop error when receive: %v", err) + + // Hold the lock to forbid batchSendLoop using the old client. + c.clientLock.Lock() + c.failPendingRequests(err) // fail all pending requests. + for { // try to re-create the streaming in the loop. + // Re-establish a application layer stream. TCP layer is handled by gRPC. + tikvClient := tikvpb.NewTikvClient(c.conn) + streamClient, err := tikvClient.BatchCommands(context.TODO()) + if err == nil { + log.Infof("batchRecvLoop re-create streaming success") + c.client = streamClient + break + } + log.Errorf("batchRecvLoop re-create streaming fail: %v", err) + // TODO: Use a more smart backoff strategy. + time.Sleep(time.Second) + } + c.clientLock.Unlock() + continue + } + + responses := resp.GetResponses() + for i, requestID := range resp.GetRequestIds() { + value, ok := c.batched.Load(requestID) + if !ok { + // There shouldn't be any unknown responses because if the old entries + // are cleaned by `failPendingRequests`, the stream must be re-created + // so that old responses will be never received. + panic("batchRecvLoop receives a unknown response") + } + entry := value.(*batchCommandsEntry) + if atomic.LoadInt32(&entry.canceled) == 0 { + // Put the response only if the request is not canceled. + entry.res <- responses[i] + } + c.batched.Delete(requestID) + } + + tikvTransportLayerLoad := resp.GetTransportLayerLoad() + if tikvTransportLayerLoad > 0.0 && cfg.MaxBatchWaitTime > 0 { + // We need to consider TiKV load only if batch-wait strategy is enabled. + atomic.StoreUint64(c.tikvTransportLayerLoad, tikvTransportLayerLoad) + } + } } func newConnArray(maxSize uint, addr string, security config.Security) (*connArray, error) { + cfg := config.GetGlobalConfig() a := &connArray{ index: 0, v: make([]*grpc.ClientConn, maxSize), streamTimeout: make(chan *tikvrpc.Lease, 1024), + + batchCommandsCh: make(chan *batchCommandsEntry, cfg.TiKVClient.MaxBatchSize), + batchCommandsClients: make([]*batchCommandsClient, 0, maxSize), + tikvTransportLayerLoad: 0, } if err := a.Init(addr, security); err != nil { return nil, err @@ -124,6 +230,7 @@ func (a *connArray) Init(addr string, security config.Security) error { ) } + allowBatch := cfg.TiKVClient.MaxBatchSize > 0 for i := range a.v { ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) conn, err := grpc.DialContext( @@ -150,8 +257,31 @@ func (a *connArray) Init(addr string, security config.Security) error { return errors.Trace(err) } a.v[i] = conn + + if allowBatch { + // Initialize batch streaming clients. + tikvClient := tikvpb.NewTikvClient(conn) + streamClient, err := tikvClient.BatchCommands(context.TODO()) + if err != nil { + a.Close() + return errors.Trace(err) + } + batchClient := &batchCommandsClient{ + conn: conn, + client: streamClient, + batched: sync.Map{}, + idAlloc: 0, + tikvTransportLayerLoad: &a.tikvTransportLayerLoad, + closed: 0, + } + a.batchCommandsClients = append(a.batchCommandsClients, batchClient) + go batchClient.batchRecvLoop(cfg.TiKVClient) + } } go tikvrpc.CheckStreamTimeoutLoop(a.streamTimeout) + if allowBatch { + go a.batchSendLoop(cfg.TiKVClient) + } return nil } @@ -162,6 +292,13 @@ func (a *connArray) Get() *grpc.ClientConn { } func (a *connArray) Close() { + // Close all batchRecvLoop. + for _, c := range a.batchCommandsClients { + // After connections are closed, `batchRecvLoop`s will check the flag. + atomic.StoreInt32(&c.closed, 1) + } + close(a.batchCommandsCh) + for i, c := range a.v { if c != nil { err := c.Close() @@ -172,6 +309,154 @@ func (a *connArray) Close() { close(a.streamTimeout) } +type batchCommandsEntry struct { + req *tikvpb.BatchCommandsRequest_Request + res chan *tikvpb.BatchCommandsResponse_Response + + // Indicated the request is canceled or not. + canceled int32 + err error +} + +// fetchAllPendingRequests fetches all pending requests from the channel. +func fetchAllPendingRequests( + ch chan *batchCommandsEntry, + maxBatchSize int, + entries *[]*batchCommandsEntry, + requests *[]*tikvpb.BatchCommandsRequest_Request, +) { + // Block on the first element. + headEntry := <-ch + if headEntry == nil { + return + } + *entries = append(*entries, headEntry) + *requests = append(*requests, headEntry.req) + + // This loop is for trying best to collect more requests. + for len(*entries) < maxBatchSize { + select { + case entry := <-ch: + if entry == nil { + return + } + *entries = append(*entries, entry) + *requests = append(*requests, entry.req) + default: + return + } + } +} + +// fetchMorePendingRequests fetches more pending requests from the channel. +func fetchMorePendingRequests( + ch chan *batchCommandsEntry, + maxBatchSize int, + batchWaitSize int, + maxWaitTime time.Duration, + entries *[]*batchCommandsEntry, + requests *[]*tikvpb.BatchCommandsRequest_Request, +) { + waitStart := time.Now() + + // Try to collect `batchWaitSize` requests, or wait `maxWaitTime`. + after := time.NewTimer(maxWaitTime) + for len(*entries) < batchWaitSize { + select { + case entry := <-ch: + if entry == nil { + return + } + *entries = append(*entries, entry) + *requests = append(*requests, entry.req) + case waitEnd := <-after.C: + metrics.TiKVBatchWaitDuration.Observe(float64(waitEnd.Sub(waitStart))) + return + } + } + after.Stop() + + // Do an additional non-block try. + for len(*entries) < maxBatchSize { + select { + case entry := <-ch: + if entry == nil { + return + } + *entries = append(*entries, entry) + *requests = append(*requests, entry.req) + default: + metrics.TiKVBatchWaitDuration.Observe(float64(time.Since(waitStart))) + return + } + } +} + +func (a *connArray) batchSendLoop(cfg config.TiKVClient) { + defer func() { + if r := recover(); r != nil { + buf := tidbutil.GetStack() + metrics.PanicCounter.WithLabelValues(metrics.LabelBatchSendLoop).Inc() + log.Errorf("batchSendLoop %v %s", r, buf) + log.Infof("Restart batchSendLoop") + go a.batchSendLoop(cfg) + } + }() + + entries := make([]*batchCommandsEntry, 0, cfg.MaxBatchSize) + requests := make([]*tikvpb.BatchCommandsRequest_Request, 0, cfg.MaxBatchSize) + requestIDs := make([]uint64, 0, cfg.MaxBatchSize) + + for { + // Choose a connection by round-robbin. + next := atomic.AddUint32(&a.index, 1) % uint32(len(a.v)) + batchCommandsClient := a.batchCommandsClients[next] + + entries = entries[:0] + requests = requests[:0] + requestIDs = requestIDs[:0] + + metrics.TiKVPendingBatchRequests.Set(float64(len(a.batchCommandsCh))) + fetchAllPendingRequests(a.batchCommandsCh, int(cfg.MaxBatchSize), &entries, &requests) + + if len(entries) < int(cfg.MaxBatchSize) && cfg.MaxBatchWaitTime > 0 { + tikvTransportLayerLoad := atomic.LoadUint64(batchCommandsClient.tikvTransportLayerLoad) + // If the target TiKV is overload, wait a while to collect more requests. + if uint(tikvTransportLayerLoad) >= cfg.OverloadThreshold { + fetchMorePendingRequests( + a.batchCommandsCh, int(cfg.MaxBatchSize), int(cfg.BatchWaitSize), + cfg.MaxBatchWaitTime, &entries, &requests, + ) + } + } + + length := len(requests) + maxBatchID := atomic.AddUint64(&batchCommandsClient.idAlloc, uint64(length)) + for i := 0; i < length; i++ { + requestID := uint64(i) + maxBatchID - uint64(length) + requestIDs = append(requestIDs, requestID) + } + + request := &tikvpb.BatchCommandsRequest{ + Requests: requests, + RequestIds: requestIDs, + } + + // Use the lock to protect the stream client won't be replaced by RecvLoop, + // and new added request won't be removed by `failPendingRequests`. + batchCommandsClient.clientLock.Lock() + for i, requestID := range request.RequestIds { + batchCommandsClient.batched.Store(requestID, entries[i]) + } + err := batchCommandsClient.client.Send(request) + batchCommandsClient.clientLock.Unlock() + if err != nil { + log.Errorf("batch commands send error: %v", err) + batchCommandsClient.failPendingRequests(err) + } + } +} + // rpcClient is RPC client struct. // TODO: Add flow control between RPC clients in TiDB ond RPC servers in TiKV. // Since we use shared client connection to communicate to the same TiKV, it's possible @@ -237,6 +522,42 @@ func (c *rpcClient) closeConns() { c.Unlock() } +func sendBatchRequest( + ctx context.Context, + addr string, + connArray *connArray, + req *tikvpb.BatchCommandsRequest_Request, + timeout time.Duration, +) (*tikvrpc.Response, error) { + entry := &batchCommandsEntry{ + req: req, + res: make(chan *tikvpb.BatchCommandsResponse_Response, 1), + canceled: 0, + err: nil, + } + ctx1, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + select { + case connArray.batchCommandsCh <- entry: + case <-ctx1.Done(): + log.Warnf("SendRequest to %s is timeout", addr) + return nil, errors.Trace(gstatus.Error(gcodes.DeadlineExceeded, "Canceled or timeout")) + } + + select { + case res, ok := <-entry.res: + if !ok { + return nil, errors.Trace(entry.err) + } + return tikvrpc.FromBatchCommandsResponse(res), nil + case <-ctx1.Done(): + atomic.StoreInt32(&entry.canceled, 1) + log.Warnf("SendRequest to %s is canceled", addr) + return nil, errors.Trace(gstatus.Error(gcodes.DeadlineExceeded, "Canceled or timeout")) + } +} + // SendRequest sends a Request to server and receives Response. func (c *rpcClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { start := time.Now() @@ -250,6 +571,13 @@ func (c *rpcClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R if err != nil { return nil, errors.Trace(err) } + + if config.GetGlobalConfig().TiKVClient.MaxBatchSize > 0 { + if batchReq := req.ToBatchCommandsRequest(); batchReq != nil { + return sendBatchRequest(ctx, addr, connArray, batchReq, timeout) + } + } + client := tikvpb.NewTikvClient(connArray.Get()) if req.Type != tikvrpc.CmdCopStream { diff --git a/store/tikv/client_test.go b/store/tikv/client_test.go index d3b53f30b87a2..5063acb77eeee 100644 --- a/store/tikv/client_test.go +++ b/store/tikv/client_test.go @@ -32,6 +32,9 @@ type testClientSuite struct { var _ = Suite(&testClientSuite{}) func (s *testClientSuite) TestConn(c *C) { + globalConfig := config.GetGlobalConfig() + globalConfig.TiKVClient.MaxBatchSize = 0 // Disable batch. + client := newRPCClient(config.Security{}) addr := "127.0.0.1:6379" diff --git a/store/tikv/gcworker/gc_worker.go b/store/tikv/gcworker/gc_worker.go index 4561b586f176f..5dcfe97f1fbd2 100644 --- a/store/tikv/gcworker/gc_worker.go +++ b/store/tikv/gcworker/gc_worker.go @@ -122,6 +122,11 @@ const ( gcEnableValue = "true" gcDisableValue = "false" gcDefaultEnableValue = true + + gcModeKey = "tikv_gc_mode" + gcModeCentral = "central" + gcModeDistributed = "distributed" + gcModeDefault = gcModeDistributed ) var gcSafePointCacheInterval = tikv.GcSafePointCacheInterval @@ -136,6 +141,7 @@ var gcVariableComments = map[string]string{ gcSafePointKey: "All versions after safe point can be accessed. (DO NOT EDIT)", gcConcurrencyKey: "How many go routines used to do GC parallel, [1, 128], default 2", gcEnableKey: "Current GC enable status", + gcModeKey: "Mode of GC, \"central\" or \"distributed\"", } func (w *GCWorker) start(ctx context.Context, wg *sync.WaitGroup) { @@ -400,14 +406,34 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64) { w.done <- errors.Trace(err) return } - err = w.doGC(ctx, safePoint) + + useDistributedGC, err := w.checkUseDistributedGC() if err != nil { - log.Errorf("[gc worker] %s do GC returns an error %v", w.uuid, errors.ErrorStack(err)) - w.gcIsRunning = false - metrics.GCJobFailureCounter.WithLabelValues("gc").Inc() - w.done <- errors.Trace(err) - return + log.Errorf("[gc worker] %s failed to load gc mode, fall back to central mode. err: %v", w.uuid, errors.ErrorStack(err)) + metrics.GCJobFailureCounter.WithLabelValues("check_gc_mode").Inc() + useDistributedGC = false + } + + if useDistributedGC { + err = w.uploadSafePointToPD(ctx, safePoint) + if err != nil { + log.Errorf("[gc worker] %s failed to upload safe point to PD: %v", w.uuid, errors.ErrorStack(err)) + w.gcIsRunning = false + metrics.GCJobFailureCounter.WithLabelValues("upload_safe_point").Inc() + w.done <- errors.Trace(err) + return + } + } else { + err = w.doGC(ctx, safePoint) + if err != nil { + log.Errorf("[gc worker] %s do GC returns an error %v", w.uuid, errors.ErrorStack(err)) + w.gcIsRunning = false + metrics.GCJobFailureCounter.WithLabelValues("gc").Inc() + w.done <- errors.Trace(err) + return + } } + w.done <- nil } @@ -484,7 +510,7 @@ func (w *GCWorker) sendUnsafeDestroyRangeRequest(ctx context.Context, startKey [ // Get all stores every time deleting a region. So the store list is less probably to be stale. stores, err := w.pdClient.GetAllStores(ctx) if err != nil { - log.Errorf("[gc worker] %s delete ranges: got an error while trying to get store list from pd: %v", w.uuid, errors.ErrorStack(err)) + log.Errorf("[gc worker] %s delete ranges: got an error while trying to get store list from PD: %v", w.uuid, errors.ErrorStack(err)) return errors.Trace(err) } @@ -550,6 +576,28 @@ func (w *GCWorker) loadGCConcurrencyWithDefault() (int, error) { return jobConcurrency, nil } +func (w *GCWorker) checkUseDistributedGC() (bool, error) { + str, err := w.loadValueFromSysTable(gcModeKey) + if err != nil { + return false, errors.Trace(err) + } + if str == "" { + err = w.saveValueToSysTable(gcModeKey, gcModeDefault) + if err != nil { + return false, errors.Trace(err) + } + str = gcModeDefault + } + if strings.EqualFold(str, gcModeDistributed) { + return true, nil + } + if strings.EqualFold(str, gcModeCentral) { + return false, nil + } + log.Warnf("[gc worker] \"%v\" is not a valid gc mode. distributed mode will be used.", str) + return true, nil +} + func (w *GCWorker) resolveLocks(ctx context.Context, safePoint uint64) error { metrics.GCWorkerCounter.WithLabelValues("resolve_locks").Inc() @@ -642,6 +690,34 @@ func (w *GCWorker) resolveLocks(ctx context.Context, safePoint uint64) error { return nil } +func (w *GCWorker) uploadSafePointToPD(ctx context.Context, safePoint uint64) error { + var newSafePoint uint64 + var err error + + bo := tikv.NewBackoffer(ctx, tikv.GcOneRegionMaxBackoff) + for { + newSafePoint, err = w.pdClient.UpdateGCSafePoint(ctx, safePoint) + if err != nil { + if errors.Cause(err) == context.Canceled { + return errors.Trace(err) + } + err = bo.Backoff(tikv.BoPDRPC, errors.Errorf("failed to upload safe point to PD, err: %v", err)) + if err != nil { + return errors.Trace(err) + } + continue + } + break + } + + if newSafePoint != safePoint { + log.Warnf("[gc worker] %s, PD rejected our safe point %v but is using another safe point %v", w.uuid, safePoint, newSafePoint) + return errors.Errorf("PD rejected our safe point %v but is using another safe point %v", safePoint, newSafePoint) + } + log.Infof("[gc worker] %s sent safe point %v to PD", w.uuid, safePoint) + return nil +} + type gcTask struct { startKey []byte endKey []byte @@ -995,16 +1071,16 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { if err != nil { return "", errors.Trace(err) } - chk := rs[0].NewChunk() - err = rs[0].Next(ctx, chk) + req := rs[0].NewRecordBatch() + err = rs[0].Next(ctx, req) if err != nil { return "", errors.Trace(err) } - if chk.NumRows() == 0 { + if req.NumRows() == 0 { log.Debugf("[gc worker] load kv, %s:nil", key) return "", nil } - value := chk.GetRow(0).GetString(0) + value := req.GetRow(0).GetString(0) log.Debugf("[gc worker] load kv, %s:%s", key, value) return value, nil } diff --git a/store/tikv/gcworker/gc_worker_test.go b/store/tikv/gcworker/gc_worker_test.go index 1bf14628ba45b..579a8b9b456f5 100644 --- a/store/tikv/gcworker/gc_worker_test.go +++ b/store/tikv/gcworker/gc_worker_test.go @@ -214,3 +214,28 @@ func (s *testGCWorkerSuite) TestDoGC(c *C) { err = s.gcWorker.doGC(ctx, 20) c.Assert(err, IsNil) } + +func (s *testGCWorkerSuite) TestCheckGCMode(c *C) { + useDistributedGC, err := s.gcWorker.checkUseDistributedGC() + c.Assert(err, IsNil) + c.Assert(useDistributedGC, Equals, true) + // Now the row must be set to the default value. + str, err := s.gcWorker.loadValueFromSysTable(gcModeKey) + c.Assert(err, IsNil) + c.Assert(str, Equals, gcModeDistributed) + + s.gcWorker.saveValueToSysTable(gcModeKey, gcModeCentral) + useDistributedGC, err = s.gcWorker.checkUseDistributedGC() + c.Assert(err, IsNil) + c.Assert(useDistributedGC, Equals, false) + + s.gcWorker.saveValueToSysTable(gcModeKey, gcModeDistributed) + useDistributedGC, err = s.gcWorker.checkUseDistributedGC() + c.Assert(err, IsNil) + c.Assert(useDistributedGC, Equals, true) + + s.gcWorker.saveValueToSysTable(gcModeKey, "invalid_mode") + useDistributedGC, err = s.gcWorker.checkUseDistributedGC() + c.Assert(err, IsNil) + c.Assert(useDistributedGC, Equals, true) +} diff --git a/store/tikv/kv.go b/store/tikv/kv.go index beb7451718448..a816826f696d4 100644 --- a/store/tikv/kv.go +++ b/store/tikv/kv.go @@ -317,7 +317,7 @@ func (s *tikvStore) getTimestampWithRetry(bo *Backoffer) (uint64, error) { if err == nil { return startTS, nil } - err = bo.Backoff(boPDRPC, errors.Errorf("get timestamp failed: %v", err)) + err = bo.Backoff(BoPDRPC, errors.Errorf("get timestamp failed: %v", err)) if err != nil { return 0, errors.Trace(err) } diff --git a/store/tikv/region_cache.go b/store/tikv/region_cache.go index 919f3aa6d3472..18dfcf3589d80 100644 --- a/store/tikv/region_cache.go +++ b/store/tikv/region_cache.go @@ -330,7 +330,7 @@ func (c *RegionCache) loadRegion(bo *Backoffer, key []byte) (*Region, error) { var backoffErr error for { if backoffErr != nil { - err := bo.Backoff(boPDRPC, backoffErr) + err := bo.Backoff(BoPDRPC, backoffErr) if err != nil { return nil, errors.Trace(err) } @@ -364,7 +364,7 @@ func (c *RegionCache) loadRegionByID(bo *Backoffer, regionID uint64) (*Region, e var backoffErr error for { if backoffErr != nil { - err := bo.Backoff(boPDRPC, backoffErr) + err := bo.Backoff(BoPDRPC, backoffErr) if err != nil { return nil, errors.Trace(err) } @@ -437,7 +437,7 @@ func (c *RegionCache) loadStoreAddr(bo *Backoffer, id uint64) (string, error) { return "", errors.Trace(err) } err = errors.Errorf("loadStore from PD failed, id: %d, err: %v", id, err) - if err = bo.Backoff(boPDRPC, err); err != nil { + if err = bo.Backoff(BoPDRPC, err); err != nil { return "", errors.Trace(err) } continue diff --git a/store/tikv/region_request_test.go b/store/tikv/region_request_test.go index 44dd26124548c..a9a79e37e85dd 100644 --- a/store/tikv/region_request_test.go +++ b/store/tikv/region_request_test.go @@ -232,6 +232,9 @@ func (s *mockTikvGrpcServer) Coprocessor(context.Context, *coprocessor.Request) func (s *mockTikvGrpcServer) Raft(tikvpb.Tikv_RaftServer) error { return errors.New("unreachable") } +func (s *mockTikvGrpcServer) BatchRaft(tikvpb.Tikv_BatchRaftServer) error { + return errors.New("unreachable") +} func (s *mockTikvGrpcServer) Snapshot(tikvpb.Tikv_SnapshotServer) error { return errors.New("unreachable") } @@ -249,6 +252,10 @@ func (s *mockTikvGrpcServer) CoprocessorStream(*coprocessor.Request, tikvpb.Tikv return errors.New("unreachable") } +func (s *mockTikvGrpcServer) BatchCommands(tikvpb.Tikv_BatchCommandsServer) error { + return errors.New("unreachable") +} + func (s *testRegionRequestSuite) TestNoReloadRegionForGrpcWhenCtxCanceled(c *C) { // prepare a mock tikv grpc server addr := "localhost:56341" diff --git a/store/tikv/sql_fail_test.go b/store/tikv/sql_fail_test.go index 6cfc4824b3554..2ac7ef5411408 100644 --- a/store/tikv/sql_fail_test.go +++ b/store/tikv/sql_fail_test.go @@ -72,11 +72,11 @@ func (s *testSQLSuite) TestFailBusyServerCop(c *C) { defer terror.Call(rs[0].Close) } c.Assert(err, IsNil) - chk := rs[0].NewChunk() - err = rs[0].Next(context.Background(), chk) + req := rs[0].NewRecordBatch() + err = rs[0].Next(context.Background(), req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) - c.Assert(chk.GetRow(0).GetString(0), Equals, "True") + c.Assert(req.NumRows() == 0, IsFalse) + c.Assert(req.GetRow(0).GetString(0), Equals, "True") }() wg.Wait() @@ -107,13 +107,13 @@ func (s *testSQLSuite) TestCoprocessorStreamRecvTimeout(c *C) { res, err := tk.Se.Execute(ctx, "select * from t") c.Assert(err, IsNil) - chk := res[0].NewChunk() + req := res[0].NewRecordBatch() for { - err := res[0].Next(ctx, chk) + err := res[0].Next(ctx, req) c.Assert(err, IsNil) - if chk.NumRows() == 0 { + if req.NumRows() == 0 { break } - chk.Reset() + req.Reset() } } diff --git a/store/tikv/tikvrpc/tikvrpc.go b/store/tikv/tikvrpc/tikvrpc.go index 8709803c462b1..0666aa5026555 100644 --- a/store/tikv/tikvrpc/tikvrpc.go +++ b/store/tikv/tikvrpc/tikvrpc.go @@ -149,6 +149,53 @@ type Request struct { SplitRegion *kvrpcpb.SplitRegionRequest } +// ToBatchCommandsRequest converts the request to an entry in BatchCommands request. +func (req *Request) ToBatchCommandsRequest() *tikvpb.BatchCommandsRequest_Request { + switch req.Type { + case CmdGet: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Get{Get: req.Get}} + case CmdScan: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Scan{Scan: req.Scan}} + case CmdPrewrite: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Prewrite{Prewrite: req.Prewrite}} + case CmdCommit: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Commit{Commit: req.Commit}} + case CmdCleanup: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Cleanup{Cleanup: req.Cleanup}} + case CmdBatchGet: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_BatchGet{BatchGet: req.BatchGet}} + case CmdBatchRollback: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_BatchRollback{BatchRollback: req.BatchRollback}} + case CmdScanLock: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_ScanLock{ScanLock: req.ScanLock}} + case CmdResolveLock: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_ResolveLock{ResolveLock: req.ResolveLock}} + case CmdGC: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_GC{GC: req.GC}} + case CmdDeleteRange: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_DeleteRange{DeleteRange: req.DeleteRange}} + case CmdRawGet: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawGet{RawGet: req.RawGet}} + case CmdRawBatchGet: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawBatchGet{RawBatchGet: req.RawBatchGet}} + case CmdRawPut: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawPut{RawPut: req.RawPut}} + case CmdRawBatchPut: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawBatchPut{RawBatchPut: req.RawBatchPut}} + case CmdRawDelete: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawDelete{RawDelete: req.RawDelete}} + case CmdRawBatchDelete: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawBatchDelete{RawBatchDelete: req.RawBatchDelete}} + case CmdRawDeleteRange: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawDeleteRange{RawDeleteRange: req.RawDeleteRange}} + case CmdRawScan: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawScan{RawScan: req.RawScan}} + case CmdCop: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Coprocessor{Coprocessor: req.Cop}} + } + return nil +} + // Response wraps all kv/coprocessor responses. type Response struct { Type CmdType @@ -179,6 +226,53 @@ type Response struct { SplitRegion *kvrpcpb.SplitRegionResponse } +// FromBatchCommandsResponse converts a BatchCommands response to Response. +func FromBatchCommandsResponse(res *tikvpb.BatchCommandsResponse_Response) *Response { + switch res := res.GetCmd().(type) { + case *tikvpb.BatchCommandsResponse_Response_Get: + return &Response{Type: CmdGet, Get: res.Get} + case *tikvpb.BatchCommandsResponse_Response_Scan: + return &Response{Type: CmdScan, Scan: res.Scan} + case *tikvpb.BatchCommandsResponse_Response_Prewrite: + return &Response{Type: CmdPrewrite, Prewrite: res.Prewrite} + case *tikvpb.BatchCommandsResponse_Response_Commit: + return &Response{Type: CmdCommit, Commit: res.Commit} + case *tikvpb.BatchCommandsResponse_Response_Cleanup: + return &Response{Type: CmdCleanup, Cleanup: res.Cleanup} + case *tikvpb.BatchCommandsResponse_Response_BatchGet: + return &Response{Type: CmdBatchGet, BatchGet: res.BatchGet} + case *tikvpb.BatchCommandsResponse_Response_BatchRollback: + return &Response{Type: CmdBatchRollback, BatchRollback: res.BatchRollback} + case *tikvpb.BatchCommandsResponse_Response_ScanLock: + return &Response{Type: CmdScanLock, ScanLock: res.ScanLock} + case *tikvpb.BatchCommandsResponse_Response_ResolveLock: + return &Response{Type: CmdResolveLock, ResolveLock: res.ResolveLock} + case *tikvpb.BatchCommandsResponse_Response_GC: + return &Response{Type: CmdGC, GC: res.GC} + case *tikvpb.BatchCommandsResponse_Response_DeleteRange: + return &Response{Type: CmdDeleteRange, DeleteRange: res.DeleteRange} + case *tikvpb.BatchCommandsResponse_Response_RawGet: + return &Response{Type: CmdRawGet, RawGet: res.RawGet} + case *tikvpb.BatchCommandsResponse_Response_RawBatchGet: + return &Response{Type: CmdRawBatchGet, RawBatchGet: res.RawBatchGet} + case *tikvpb.BatchCommandsResponse_Response_RawPut: + return &Response{Type: CmdRawPut, RawPut: res.RawPut} + case *tikvpb.BatchCommandsResponse_Response_RawBatchPut: + return &Response{Type: CmdRawBatchPut, RawBatchPut: res.RawBatchPut} + case *tikvpb.BatchCommandsResponse_Response_RawDelete: + return &Response{Type: CmdRawDelete, RawDelete: res.RawDelete} + case *tikvpb.BatchCommandsResponse_Response_RawBatchDelete: + return &Response{Type: CmdRawBatchDelete, RawBatchDelete: res.RawBatchDelete} + case *tikvpb.BatchCommandsResponse_Response_RawDeleteRange: + return &Response{Type: CmdRawDeleteRange, RawDeleteRange: res.RawDeleteRange} + case *tikvpb.BatchCommandsResponse_Response_RawScan: + return &Response{Type: CmdRawScan, RawScan: res.RawScan} + case *tikvpb.BatchCommandsResponse_Response_Coprocessor: + return &Response{Type: CmdCop, Cop: res.Coprocessor} + } + return nil +} + // CopStreamResponse combinates tikvpb.Tikv_CoprocessorStreamClient and the first Recv() result together. // In streaming API, get grpc stream client may not involve any network packet, then region error have // to be handled in Recv() function. This struct facilitates the error handling. diff --git a/table/tables/tables.go b/table/tables/tables.go index 50ef94a4500d0..82d7e4aad2949 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -919,6 +919,16 @@ func (t *tableCommon) AllocAutoID(ctx sessionctx.Context) (int64, error) { return 0, errors.Trace(err) } if t.meta.ShardRowIDBits > 0 { + if t.overflowShardBits(rowID) { + // If overflow, the rowID may be duplicated. For examples, + // t.meta.ShardRowIDBits = 4 + // rowID = 0010111111111111111111111111111111111111111111111111111111111111 + // shard = 01000000000000000000000000000000000000000000000000000000000000000 + // will be duplicated with: + // rowID = 0100111111111111111111111111111111111111111111111111111111111111 + // shard = 0010000000000000000000000000000000000000000000000000000000000000 + return 0, autoid.ErrAutoincReadFailed + } txnCtx := ctx.GetSessionVars().TxnCtx if txnCtx.Shard == nil { shard := t.calcShard(txnCtx.StartTS) @@ -929,6 +939,12 @@ func (t *tableCommon) AllocAutoID(ctx sessionctx.Context) (int64, error) { return rowID, nil } +// overflowShardBits check whether the rowID overflow `1<<(64-t.meta.ShardRowIDBits-1) -1`. +func (t *tableCommon) overflowShardBits(rowID int64) bool { + mask := (1< 0 +} + func (t *tableCommon) calcShard(startTS uint64) int64 { var buf [8]byte binary.LittleEndian.PutUint64(buf[:], startTS) diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index d9307bb307882..17a9c49c4f70b 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -169,10 +169,10 @@ func (ts *testSuite) TestTypes(c *C) { c.Assert(err, IsNil) rs, err := ts.se.Execute(ctx, "select * from test.t where c1 = 1") c.Assert(err, IsNil) - chk := rs[0].NewChunk() - err = rs[0].Next(ctx, chk) + req := rs[0].NewRecordBatch() + err = rs[0].Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) + c.Assert(req.NumRows() == 0, IsFalse) c.Assert(rs[0].Close(), IsNil) _, err = ts.se.Execute(ctx, "drop table test.t") c.Assert(err, IsNil) @@ -183,11 +183,11 @@ func (ts *testSuite) TestTypes(c *C) { c.Assert(err, IsNil) rs, err = ts.se.Execute(ctx, "select * from test.t where c1 = 1") c.Assert(err, IsNil) - chk = rs[0].NewChunk() - err = rs[0].Next(ctx, chk) + req = rs[0].NewRecordBatch() + err = rs[0].Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) - row := chk.GetRow(0) + c.Assert(req.NumRows() == 0, IsFalse) + row := req.GetRow(0) c.Assert(types.BinaryLiteral(row.GetBytes(5)), DeepEquals, types.NewBinaryLiteralFromUint(6, -1)) c.Assert(rs[0].Close(), IsNil) _, err = ts.se.Execute(ctx, "drop table test.t") @@ -199,11 +199,11 @@ func (ts *testSuite) TestTypes(c *C) { c.Assert(err, IsNil) rs, err = ts.se.Execute(ctx, "select c1 + 1 from test.t where c1 = 1") c.Assert(err, IsNil) - chk = rs[0].NewChunk() - err = rs[0].Next(ctx, chk) + req = rs[0].NewRecordBatch() + err = rs[0].Next(ctx, req) c.Assert(err, IsNil) - c.Assert(chk.NumRows() == 0, IsFalse) - c.Assert(chk.GetRow(0).GetFloat64(0), DeepEquals, float64(2)) + c.Assert(req.NumRows() == 0, IsFalse) + c.Assert(req.GetRow(0).GetFloat64(0), DeepEquals, float64(2)) c.Assert(rs[0].Close(), IsNil) _, err = ts.se.Execute(ctx, "drop table test.t") c.Assert(err, IsNil) diff --git a/tidb-server/main.go b/tidb-server/main.go index 34b98aa59e7cf..c074779832baa 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -14,6 +14,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -79,6 +80,8 @@ const ( nmMetricsInterval = "metrics-interval" nmDdlLease = "lease" nmTokenLimit = "token-limit" + nmPluginDir = "plugin-dir" + nmPluginLoad = "plugin-load" nmProxyProtocolNetworks = "proxy-protocol-networks" nmProxyProtocolHeaderTimeout = "proxy-protocol-header-timeout" @@ -100,6 +103,8 @@ var ( runDDL = flagBoolean(nmRunDDL, true, "run ddl worker on this tidb-server") ddlLease = flag.String(nmDdlLease, "45s", "schema lease duration, very dangerous to change only if you know what you do") tokenLimit = flag.Int(nmTokenLimit, 1000, "the limit of concurrent executed sessions") + pluginDir = flag.String(nmPluginDir, "/data/deploy/plugin", "the folder that hold plugin") + pluginLoad = flag.String(nmPluginLoad, "", "wait load plugin name(seperated by comma)") // Log logLevel = flag.String(nmLogLevel, "info", "log level: info, debug, warn, error, fatal") @@ -322,6 +327,12 @@ func overrideConfig() { if actualFlags[nmTokenLimit] { cfg.TokenLimit = uint(*tokenLimit) } + if actualFlags[nmPluginLoad] { + cfg.Plugin.Load = *pluginLoad + } + if actualFlags[nmPluginDir] { + cfg.Plugin.Dir = *pluginDir + } // Log if actualFlags[nmLogLevel] { @@ -543,9 +554,9 @@ func closeDomainAndStorage() { func cleanup() { if graceful { - svr.GracefulDown() + svr.GracefulDown(context.Background(), nil) } else { - svr.KillAllConnections() + svr.TryGracefulDown() } closeDomainAndStorage() } diff --git a/types/convert.go b/types/convert.go index 27df51b771997..87e3cd82e24a8 100644 --- a/types/convert.go +++ b/types/convert.go @@ -106,7 +106,11 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) { } // ConvertIntToUint converts an int value to an uint value. -func ConvertIntToUint(val int64, upperBound uint64, tp byte) (uint64, error) { +func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) { + if sc.ShouldClipToZero() && val < 0 { + return 0, overflow(val, tp) + } + if uint64(val) > upperBound { return upperBound, overflow(val, tp) } @@ -124,9 +128,12 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) { } // ConvertFloatToUint converts a float value to an uint value. -func ConvertFloatToUint(fval float64, upperBound uint64, tp byte) (uint64, error) { +func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) { val := RoundFloat(fval) if val < 0 { + if sc.ShouldClipToZero() { + return 0, overflow(val, tp) + } return uint64(int64(val)), overflow(val, tp) } @@ -400,7 +407,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j json.BinaryJSON, unsigned return ConvertFloatToInt(f, lBound, uBound, mysql.TypeDouble) } bound := UnsignedUpperBound[mysql.TypeLonglong] - u, err := ConvertFloatToUint(f, bound, mysql.TypeDouble) + u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeDouble) return int64(u), errors.Trace(err) case json.TypeCodeString: return StrToInt(sc, hack.String(j.GetString())) @@ -423,7 +430,7 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float6 case json.TypeCodeInt64: return float64(j.GetInt64()), nil case json.TypeCodeUint64: - u, err := ConvertIntToUint(j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) + u, err := ConvertIntToUint(sc, j.GetInt64(), UnsignedUpperBound[mysql.TypeLonglong], mysql.TypeLonglong) return float64(u), errors.Trace(err) case json.TypeCodeFloat64: return j.GetFloat64(), nil diff --git a/types/datum.go b/types/datum.go index a22a1e50634cc..ed340e00aaf33 100644 --- a/types/datum.go +++ b/types/datum.go @@ -866,21 +866,21 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( ) switch d.k { case KindInt64: - val, err = ConvertIntToUint(d.GetInt64(), upperBound, tp) + val, err = ConvertIntToUint(sc, d.GetInt64(), upperBound, tp) case KindUint64: val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp) case KindFloat32, KindFloat64: - val, err = ConvertFloatToUint(d.GetFloat64(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp) case KindString, KindBytes: - val, err = StrToUint(sc, d.GetString()) - if err != nil { - return ret, errors.Trace(err) + uval, err1 := StrToUint(sc, d.GetString()) + if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() { + return ret, errors.Trace(err1) } - val, err = ConvertUintToUint(val, upperBound, tp) + val, err = ConvertUintToUint(uval, upperBound, tp) if err != nil { return ret, errors.Trace(err) } - ret.SetUint64(val) + err = err1 case KindMysqlTime: dec := d.GetMysqlTime().ToNumber() err = dec.Round(dec, 0, ModeHalfEven) @@ -888,7 +888,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( if err == nil { err = err1 } - val, err1 = ConvertIntToUint(ival, upperBound, tp) + val, err1 = ConvertIntToUint(sc, ival, upperBound, tp) if err == nil { err = err1 } @@ -897,18 +897,18 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( err = dec.Round(dec, 0, ModeHalfEven) ival, err1 := dec.ToInt() if err1 == nil { - val, err = ConvertIntToUint(ival, upperBound, tp) + val, err = ConvertIntToUint(sc, ival, upperBound, tp) } case KindMysqlDecimal: fval, err1 := d.GetMysqlDecimal().ToFloat64() - val, err = ConvertFloatToUint(fval, upperBound, tp) + val, err = ConvertFloatToUint(sc, fval, upperBound, tp) if err == nil { err = err1 } case KindMysqlEnum: - val, err = ConvertFloatToUint(d.GetMysqlEnum().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp) case KindMysqlSet: - val, err = ConvertFloatToUint(d.GetMysqlSet().ToNumber(), upperBound, tp) + val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp) case KindBinaryLiteral, KindMysqlBit: val, err = d.GetBinaryLiteral().ToInt(sc) case KindMysqlJSON: @@ -1138,7 +1138,7 @@ func ProduceDecWithSpecifiedTp(dec *MyDecimal, tp *FieldType, sc *stmtctx.Statem return nil, errors.Trace(err) } if !dec.IsZero() && frac > decimal && dec.Compare(&old) != 0 { - if sc.InInsertStmt || sc.InUpdateOrDeleteStmt { + if sc.InInsertStmt || sc.InUpdateStmt || sc.InDeleteStmt { // fix https://github.com/pingcap/tidb/issues/3895 // fix https://github.com/pingcap/tidb/issues/5532 sc.AppendWarning(ErrTruncated) diff --git a/types/format_test.go b/types/format_test.go index 16fbf6a963971..493ac31e75ccf 100644 --- a/types/format_test.go +++ b/types/format_test.go @@ -103,6 +103,17 @@ func (s *testTimeSuite) TestStrToDate(c *C) { {`10:13 PM`, `%l:%i %p`, types.FromDate(0, 0, 0, 22, 13, 0, 0)}, {`12:00:00 AM`, `%h:%i:%s %p`, types.FromDate(0, 0, 0, 0, 0, 0, 0)}, {`12:00:00 PM`, `%h:%i:%s %p`, types.FromDate(0, 0, 0, 12, 0, 0, 0)}, + {`18/10/22`, `%y/%m/%d`, types.FromDate(2018, 10, 22, 0, 0, 0, 0)}, + {`8/10/22`, `%y/%m/%d`, types.FromDate(2008, 10, 22, 0, 0, 0, 0)}, + {`69/10/22`, `%y/%m/%d`, types.FromDate(2069, 10, 22, 0, 0, 0, 0)}, + {`70/10/22`, `%y/%m/%d`, types.FromDate(1970, 10, 22, 0, 0, 0, 0)}, + {`18/10/22`, `%Y/%m/%d`, types.FromDate(2018, 10, 22, 0, 0, 0, 0)}, + {`2018/10/22`, `%Y/%m/%d`, types.FromDate(2018, 10, 22, 0, 0, 0, 0)}, + {`8/10/22`, `%Y/%m/%d`, types.FromDate(2008, 10, 22, 0, 0, 0, 0)}, + {`69/10/22`, `%Y/%m/%d`, types.FromDate(2069, 10, 22, 0, 0, 0, 0)}, + {`70/10/22`, `%Y/%m/%d`, types.FromDate(1970, 10, 22, 0, 0, 0, 0)}, + {`18/10/22`, `%Y/%m/%d`, types.FromDate(2018, 10, 22, 0, 0, 0, 0)}, + {`100/10/22`, `%Y/%m/%d`, types.FromDate(100, 10, 22, 0, 0, 0, 0)}, } for i, tt := range tests { var t types.Time @@ -121,6 +132,7 @@ func (s *testTimeSuite) TestStrToDate(c *C) { {`23:60:12`, `%T`}, // invalid minute {`18`, `%l`}, {`00:21:22 AM`, `%h:%i:%s %p`}, + {`100/10/22`, `%y/%m/%d`}, } for _, tt := range errTests { var t types.Time diff --git a/types/mydecimal.go b/types/mydecimal.go index afb34eca3523f..30c0d736bfcd8 100644 --- a/types/mydecimal.go +++ b/types/mydecimal.go @@ -234,6 +234,23 @@ func (d *MyDecimal) removeLeadingZeros() (wordIdx int, digitsInt int) { return } +func (d *MyDecimal) removeTrailingZeros() (lastWordIdx int, digitsFrac int) { + digitsFrac = int(d.digitsFrac) + i := ((digitsFrac - 1) % digitsPerWord) + 1 + lastWordIdx = digitsToWords(int(d.digitsInt)) + digitsToWords(int(d.digitsFrac)) + for digitsFrac > 0 && d.wordBuf[lastWordIdx-1] == 0 { + digitsFrac -= i + i = digitsPerWord + lastWordIdx-- + } + if digitsFrac > 0 { + digitsFrac -= countTrailingZeroes(9-((digitsFrac-1)%digitsPerWord), d.wordBuf[lastWordIdx-1]) + } else { + digitsFrac = 0 + } + return +} + // ToString converts decimal to its printable string representation without rounding. // // RETURN VALUE @@ -1212,6 +1229,26 @@ func (d *MyDecimal) ToBin(precision, frac int) ([]byte, error) { return bin, err } +// ToHashKey removes the leading and trailing zeros and generates a hash key. +// Two Decimals dec0 and dec1 with different fraction will generate the same hash keys if dec0.Compare(dec1) == 0. +func (d *MyDecimal) ToHashKey() ([]byte, error) { + _, digitsInt := d.removeLeadingZeros() + _, digitsFrac := d.removeTrailingZeros() + prec := digitsInt + digitsFrac + if prec == 0 { // zeroDecimal + prec = 1 + } + buf, err := d.ToBin(prec, digitsFrac) + if err == ErrTruncated { + // This err is caused by shorter digitsFrac; + // After removing the trailing zeros from a Decimal, + // so digitsFrac may be less than the real digitsFrac of the Decimal, + // thus ErrTruncated may be raised, we can ignore it here. + err = nil + } + return buf, err +} + // PrecisionAndFrac returns the internal precision and frac number. func (d *MyDecimal) PrecisionAndFrac() (precision, frac int) { frac = int(d.digitsFrac) diff --git a/types/mydecimal_test.go b/types/mydecimal_test.go index 1770554629b8c..e799692231c6a 100644 --- a/types/mydecimal_test.go +++ b/types/mydecimal_test.go @@ -15,6 +15,7 @@ package types import ( "strings" + "testing" . "github.com/pingcap/check" ) @@ -145,6 +146,114 @@ func (s *testMyDecimalSuite) TestToFloat(c *C) { } } +func (s *testMyDecimalSuite) TestToHashKey(c *C) { + tests := []struct { + numbers []string + }{ + {[]string{"1.1", "1.1000", "1.1000000", "1.10000000000", "01.1", "0001.1", "001.1000000"}}, + {[]string{"-1.1", "-1.1000", "-1.1000000", "-1.10000000000", "-01.1", "-0001.1", "-001.1000000"}}, + {[]string{".1", "0.1", "000000.1", ".10000", "0000.10000", "000000000000000000.1"}}, + {[]string{"0", "0000", ".0", ".00000", "00000.00000", "-0", "-0000", "-.0", "-.00000", "-00000.00000"}}, + {[]string{".123456789123456789", ".1234567891234567890", ".12345678912345678900", ".123456789123456789000", ".1234567891234567890000", "0.123456789123456789", + ".1234567891234567890000000000", "0000000.123456789123456789000"}}, + {[]string{"12345", "012345", "0012345", "0000012345", "0000000012345", "00000000000012345", "12345.", "12345.00", "12345.000000000", "000012345.0000"}}, + {[]string{"123E5", "12300000", "00123E5", "000000123E5", "12300000.00000000"}}, + {[]string{"123E-2", "1.23", "00000001.23", "1.2300000000000000", "000000001.23000000000000"}}, + } + for _, ca := range tests { + keys := make([]string, 0, len(ca.numbers)) + for _, num := range ca.numbers { + var dec MyDecimal + c.Check(dec.FromString([]byte(num)), IsNil) + key, err := dec.ToHashKey() + c.Check(err, IsNil) + keys = append(keys, string(key)) + } + + for i := 1; i < len(keys); i++ { + c.Check(keys[0], Equals, keys[i]) + } + } + + binTests := []struct { + hashNumbers []string + binNumbers []string + }{ + {[]string{"1.1", "1.1000", "1.1000000", "1.10000000000", "01.1", "0001.1", "001.1000000"}, + []string{"1.1", "0001.1", "01.1"}}, + {[]string{"-1.1", "-1.1000", "-1.1000000", "-1.10000000000", "-01.1", "-0001.1", "-001.1000000"}, + []string{"-1.1", "-0001.1", "-01.1"}}, + {[]string{".1", "0.1", "000000.1", ".10000", "0000.10000", "000000000000000000.1"}, + []string{".1", "0.1", "000000.1", "00.1"}}, + {[]string{"0", "0000", ".0", ".00000", "00000.00000", "-0", "-0000", "-.0", "-.00000", "-00000.00000"}, + []string{"0", "0000", "00", "-0", "-00", "-000000"}}, + {[]string{".123456789123456789", ".1234567891234567890", ".12345678912345678900", ".123456789123456789000", ".1234567891234567890000", "0.123456789123456789", + ".1234567891234567890000000000", "0000000.123456789123456789000"}, + []string{".123456789123456789", "0.123456789123456789", "0000.123456789123456789", "0000000.123456789123456789"}}, + {[]string{"12345", "012345", "0012345", "0000012345", "0000000012345", "00000000000012345", "12345.", "12345.00", "12345.000000000", "000012345.0000"}, + []string{"12345", "012345", "000012345", "000000000000012345"}}, + {[]string{"123E5", "12300000", "00123E5", "000000123E5", "12300000.00000000"}, + []string{"12300000", "123E5", "00123E5", "0000000000123E5"}}, + {[]string{"123E-2", "1.23", "00000001.23", "1.2300000000000000", "000000001.23000000000000"}, + []string{"123E-2", "1.23", "000001.23", "0000000000001.23"}}, + } + for _, ca := range binTests { + keys := make([]string, 0, len(ca.hashNumbers)+len(ca.binNumbers)) + for _, num := range ca.hashNumbers { + var dec MyDecimal + c.Check(dec.FromString([]byte(num)), IsNil) + key, err := dec.ToHashKey() + c.Check(err, IsNil) + keys = append(keys, string(key)) + } + for _, num := range ca.binNumbers { + var dec MyDecimal + c.Check(dec.FromString([]byte(num)), IsNil) + prec, frac := dec.PrecisionAndFrac() // remove leading zeros but trailing zeros remain + key, err := dec.ToBin(prec, frac) + c.Check(err, IsNil) + keys = append(keys, string(key)) + } + + for i := 1; i < len(keys); i++ { + c.Check(keys[0], Equals, keys[i]) + } + } +} + +func (s *testMyDecimalSuite) TestRemoveTrailingZeros(c *C) { + tests := []string{ + "0", "0.0", ".0", ".00000000", "0.0000", "0000", "0000.0", "0000.000", + "-0", "-0.0", "-.0", "-.00000000", "-0.0000", "-0000", "-0000.0", "-0000.000", + "123123123", "213123.", "21312.000", "21321.123", "213.1230000", "213123.000123000", + "-123123123", "-213123.", "-21312.000", "-21321.123", "-213.1230000", "-213123.000123000", + "123E5", "12300E-5", "0.00100E1", "0.001230E-3", + "123987654321.123456789000", "000000000123", "123456789.987654321", "999.999000", + } + for _, ca := range tests { + var dec MyDecimal + c.Check(dec.FromString([]byte(ca)), IsNil) + + // calculate the number of digits after point but trailing zero + digitsFracExp := 0 + str := string(dec.ToString()) + point := strings.Index(str, ".") + if point != -1 { + pos := len(str) - 1 + for pos > point { + if str[pos] != '0' { + break + } + pos-- + } + digitsFracExp = pos - point + } + + _, digitsFrac := dec.removeTrailingZeros() + c.Check(digitsFrac, Equals, digitsFracExp) + } +} + func (s *testMyDecimalSuite) TestShift(c *C) { type tcase struct { input string @@ -778,3 +887,57 @@ func (s *testMyDecimalSuite) TestMaxOrMin(c *C) { c.Assert(dec.String(), Equals, tt.result) } } + +func benchmarkMyDecimalToBinOrHashCases() []string { + return []string{ + "1.000000000000", "3", "12.000000000", "120", + "120000", "100000000000.00000", "0.000000001200000000", + "98765.4321", "-123.456000000000000000", + "0", "0000000000", "0.00000000000", + } +} + +func BenchmarkMyDecimalToBin(b *testing.B) { + cases := benchmarkMyDecimalToBinOrHashCases() + decs := make([]*MyDecimal, 0, len(cases)) + for _, ca := range cases { + var dec MyDecimal + if err := dec.FromString([]byte(ca)); err != nil { + b.Fatal(err) + } + decs = append(decs, &dec) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, dec := range decs { + prec, frac := dec.PrecisionAndFrac() + _, err := dec.ToBin(prec, frac) + if err != nil { + b.Fatal(err) + } + } + } +} + +func BenchmarkMyDecimalToHashKey(b *testing.B) { + cases := benchmarkMyDecimalToBinOrHashCases() + decs := make([]*MyDecimal, 0, len(cases)) + for _, ca := range cases { + var dec MyDecimal + if err := dec.FromString([]byte(ca)); err != nil { + b.Fatal(err) + } + decs = append(decs, &dec) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, dec := range decs { + _, err := dec.ToHashKey() + if err != nil { + b.Fatal(err) + } + } + } +} diff --git a/types/parser_driver/value_expr.go b/types/parser_driver/value_expr.go index 08b6872ce2f21..9ab8e66e8f650 100644 --- a/types/parser_driver/value_expr.go +++ b/types/parser_driver/value_expr.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/format" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/hack" @@ -69,7 +70,7 @@ type ValueExpr struct { } // Restore implements Node interface. -func (n *ValueExpr) Restore(ctx *ast.RestoreCtx) error { +func (n *ValueExpr) Restore(ctx *format.RestoreCtx) error { switch n.Kind() { case types.KindNull: ctx.WriteKeyWord("NULL") @@ -195,7 +196,7 @@ type ParamMarkerExpr struct { } // Restore implements Node interface. -func (n *ParamMarkerExpr) Restore(ctx *ast.RestoreCtx) error { +func (n *ParamMarkerExpr) Restore(ctx *format.RestoreCtx) error { ctx.WritePlain("?") return nil } diff --git a/types/parser_driver/value_exprl_test.go b/types/parser_driver/value_expr_test.go similarity index 93% rename from types/parser_driver/value_exprl_test.go rename to types/parser_driver/value_expr_test.go index 07746f77dc42d..7cbc6643246a6 100644 --- a/types/parser_driver/value_exprl_test.go +++ b/types/parser_driver/value_expr_test.go @@ -18,7 +18,7 @@ import ( "testing" . "github.com/pingcap/check" - "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/format" "github.com/pingcap/tidb/types" ) @@ -52,7 +52,7 @@ func (s *testValueExprRestoreSuite) TestValueExprRestore(c *C) { for _, testCase := range testCases { sb.Reset() expr := &ValueExpr{Datum: testCase.datum} - err := expr.Restore(ast.NewRestoreCtx(ast.DefaultRestoreFlags, &sb)) + err := expr.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) c.Assert(err, IsNil) c.Assert(sb.String(), Equals, testCase.expect, Commentf("Datum: %#v", testCase.datum)) } diff --git a/types/time.go b/types/time.go index fec734025aea1..41d32a7422245 100644 --- a/types/time.go +++ b/types/time.go @@ -2165,6 +2165,8 @@ var dateFormatParserTable = map[string]dateFormatParser{ "%S": secondsNumeric, // Seconds (00..59) "%T": time24Hour, // Time, 24-hour (hh:mm:ss) "%Y": yearNumericFourDigits, // Year, numeric, four digits + // Deprecated since MySQL 5.7.5 + "%y": yearNumericTwoDigits, // Year, numeric (two digits) // TODO: Add the following... // "%a": abbreviatedWeekday, // Abbreviated weekday name (Sun..Sat) // "%D": dayOfMonthWithSuffix, // Day of the month with English suffix (0th, 1st, 2nd, 3rd) @@ -2176,8 +2178,6 @@ var dateFormatParserTable = map[string]dateFormatParser{ // "%w": dayOfWeek, // Day of the week (0=Sunday..6=Saturday) // "%X": yearOfWeek, // Year for the week where Sunday is the first day of the week, numeric, four digits; used with %V // "%x": yearOfWeek, // Year for the week, where Monday is the first day of the week, numeric, four digits; used with %v - // Deprecated since MySQL 5.7.5 - // "%y": yearTwoDigits, // Year, numeric (two digits) } // GetFormatType checks the type(Duration, Date or Datetime) of a format string. @@ -2235,7 +2235,7 @@ func matchDateWithToken(t *MysqlTime, date string, token string, ctx map[string] } func parseDigits(input string, count int) (int, bool) { - if len(input) < count { + if count <= 0 || len(input) < count { return 0, false } @@ -2432,12 +2432,31 @@ func microSeconds(t *MysqlTime, input string, ctx map[string]int) (string, bool) } func yearNumericFourDigits(t *MysqlTime, input string, ctx map[string]int) (string, bool) { - v, succ := parseDigits(input, 4) - if !succ { + return yearNumericNDigits(t, input, ctx, 4) +} + +func yearNumericTwoDigits(t *MysqlTime, input string, ctx map[string]int) (string, bool) { + return yearNumericNDigits(t, input, ctx, 2) +} + +func yearNumericNDigits(t *MysqlTime, input string, ctx map[string]int, n int) (string, bool) { + effectiveCount, effectiveValue := 0, 0 + for effectiveCount+1 <= n { + value, succeed := parseDigits(input, effectiveCount+1) + if !succeed { + break + } + effectiveCount++ + effectiveValue = value + } + if effectiveCount == 0 { return input, false } - t.year = uint16(v) - return input[4:], true + if effectiveCount <= 2 { + effectiveValue = adjustYear(effectiveValue) + } + t.year = uint16(effectiveValue) + return input[effectiveCount:], true } func dayOfYearThreeDigits(t *MysqlTime, input string, ctx map[string]int) (string, bool) { diff --git a/util/chunk/recordbatch.go b/util/chunk/recordbatch.go new file mode 100644 index 0000000000000..7eb79f54f4333 --- /dev/null +++ b/util/chunk/recordbatch.go @@ -0,0 +1,24 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package chunk + +// RecordBatch is input parameter of Executor.Next` method. +type RecordBatch struct { + *Chunk +} + +// NewRecordBatch is used to construct a RecordBatch. +func NewRecordBatch(chk *Chunk) *RecordBatch { + return &RecordBatch{chk} +} diff --git a/util/codec/codec.go b/util/codec/codec.go index fae2935538246..0af81e36ae5f7 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -91,9 +91,8 @@ func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparab if hash { // If hash is true, we only consider the original value of this decimal and ignore it's precision. dec := vals[i].GetMysqlDecimal() - precision, frac := dec.PrecisionAndFrac() var bin []byte - bin, err = dec.ToBin(precision, frac) + bin, err = dec.ToHashKey() if err != nil { return nil, errors.Trace(err) } @@ -238,9 +237,7 @@ func encodeChunkRow(sc *stmtctx.StatementContext, b []byte, row chunk.Row, allTy if hash { // If hash is true, we only consider the original value of this decimal and ignore it's precision. dec := row.GetMyDecimal(i) - precision, frac := dec.PrecisionAndFrac() - var bin []byte - bin, err = dec.ToBin(precision, frac) + bin, err := dec.ToHashKey() if err != nil { return nil, errors.Trace(err) } diff --git a/util/encrypt/aes.go b/util/encrypt/aes.go index 435f3bdc7c6ee..7b1d39644b059 100644 --- a/util/encrypt/aes.go +++ b/util/encrypt/aes.go @@ -173,6 +173,30 @@ func AESDecryptWithCBC(cryptStr, key []byte, iv []byte) ([]byte, error) { return aesDecrypt(cryptStr, mode) } +// AESEncryptWithOFB encrypts data using AES with OFB mode. +func AESEncryptWithOFB(plainStr []byte, key []byte, iv []byte) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := cipher.NewOFB(cb, iv) + crypted := make([]byte, len(plainStr)) + mode.XORKeyStream(crypted, plainStr) + return crypted, nil +} + +// AESDecryptWithOFB decrypts data using AES with OFB mode. +func AESDecryptWithOFB(cipherStr []byte, key []byte, iv []byte) ([]byte, error) { + cb, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Trace(err) + } + mode := cipher.NewOFB(cb, iv) + plainStr := make([]byte, len(cipherStr)) + mode.XORKeyStream(plainStr, cipherStr) + return plainStr, nil +} + // AESEncryptWithCFB decrypts data using AES with CFB mode. func AESEncryptWithCFB(cryptStr, key []byte, iv []byte) ([]byte, error) { cb, err := aes.NewCipher(key) diff --git a/util/encrypt/aes_test.go b/util/encrypt/aes_test.go index 14cd1e4b77fa2..0ea417c82090e 100644 --- a/util/encrypt/aes_test.go +++ b/util/encrypt/aes_test.go @@ -304,6 +304,75 @@ func (s *testEncryptSuite) TestAESEncryptWithCBC(c *C) { } } +func (s *testEncryptSuite) TestAESEncryptWithOFB(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + str string + key string + iv string + expect string + isError bool + }{ + // 128 bits key + {"pingcap", "1234567890123456", "1234567890123456", "0515A36BBF3DE0", false}, + {"pingcap123", "1234567890123456", "1234567890123456", "0515A36BBF3DE0DBE9DD", false}, + // 192 bits key + {"pingcap", "123456789012345678901234", "1234567890123456", "45A57592449893", false}, // 192 bit + // negtive cases: invalid key length + {"pingcap", "12345678901234567", "1234567890123456", "", true}, + {"pingcap", "123456789012345", "1234567890123456", "", true}, + } + + for _, t := range tests { + str := []byte(t.str) + key := []byte(t.key) + iv := []byte(t.iv) + + crypted, err := AESEncryptWithOFB(str, key, iv) + if t.isError { + c.Assert(err, NotNil, Commentf("%v", t)) + continue + } + c.Assert(err, IsNil, Commentf("%v", t)) + result := toHex(crypted) + c.Assert(result, Equals, t.expect, Commentf("%v", t)) + } +} + +func (s *testEncryptSuite) TestAESDecryptWithOFB(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + str string + key string + iv string + expect string + isError bool + }{ + // 128 bits key + {"0515A36BBF3DE0", "1234567890123456", "1234567890123456", "pingcap", false}, + {"0515A36BBF3DE0DBE9DD", "1234567890123456", "1234567890123456", "pingcap123", false}, + // 192 bits key + {"45A57592449893", "123456789012345678901234", "1234567890123456", "pingcap", false}, // 192 bit + // negtive cases: invalid key length + {"pingcap", "12345678901234567", "1234567890123456", "", true}, + {"pingcap", "123456789012345", "1234567890123456", "", true}, + } + + for _, t := range tests { + str, _ := hex.DecodeString(t.str) + key := []byte(t.key) + iv := []byte(t.iv) + + plainText, err := AESDecryptWithOFB(str, key, iv) + if t.isError { + c.Assert(err, NotNil, Commentf("%v", t)) + continue + } + c.Assert(err, IsNil, Commentf("%v", t)) + c.Assert(string(plainText), Equals, t.expect, Commentf("%v", t)) + } +} + func (s *testEncryptSuite) TestAESDecryptWithCBC(c *C) { defer testleak.AfterTest(c)() tests := []struct { @@ -342,7 +411,7 @@ func (s *testEncryptSuite) TestAESDecryptWithCBC(c *C) { } } -func (s *testEncryptSuite) TestAESEncryptWithOFB(c *C) { +func (s *testEncryptSuite) TestAESEncryptWithCFB(c *C) { defer testleak.AfterTest(c)() tests := []struct { str string diff --git a/util/execdetails/execdetails.go b/util/execdetails/execdetails.go index 184f5d4fde6e4..aaebbda1a3cc8 100644 --- a/util/execdetails/execdetails.go +++ b/util/execdetails/execdetails.go @@ -53,13 +53,13 @@ type CommitDetails struct { func (d ExecDetails) String() string { parts := make([]string, 0, 6) if d.ProcessTime > 0 { - parts = append(parts, fmt.Sprintf("process_time:%v", d.ProcessTime)) + parts = append(parts, fmt.Sprintf("process_time:%vs", d.ProcessTime.Seconds())) } if d.WaitTime > 0 { - parts = append(parts, fmt.Sprintf("wait_time:%v", d.WaitTime)) + parts = append(parts, fmt.Sprintf("wait_time:%vs", d.WaitTime.Seconds())) } if d.BackoffTime > 0 { - parts = append(parts, fmt.Sprintf("backoff_time:%v", d.BackoffTime)) + parts = append(parts, fmt.Sprintf("backoff_time:%vs", d.BackoffTime.Seconds())) } if d.RequestCount > 0 { parts = append(parts, fmt.Sprintf("request_count:%d", d.RequestCount)) @@ -73,23 +73,23 @@ func (d ExecDetails) String() string { commitDetails := d.CommitDetail if commitDetails != nil { if commitDetails.PrewriteTime > 0 { - parts = append(parts, fmt.Sprintf("prewrite_time:%v", commitDetails.PrewriteTime)) + parts = append(parts, fmt.Sprintf("prewrite_time:%vs", commitDetails.PrewriteTime.Seconds())) } if commitDetails.CommitTime > 0 { - parts = append(parts, fmt.Sprintf("commit_time:%v", commitDetails.CommitTime)) + parts = append(parts, fmt.Sprintf("commit_time:%vs", commitDetails.CommitTime.Seconds())) } if commitDetails.GetCommitTsTime > 0 { - parts = append(parts, fmt.Sprintf("get_commit_ts_time:%v", commitDetails.GetCommitTsTime)) + parts = append(parts, fmt.Sprintf("get_commit_ts_time:%vs", commitDetails.GetCommitTsTime.Seconds())) } if commitDetails.TotalBackoffTime > 0 { - parts = append(parts, fmt.Sprintf("total_backoff_time:%v", commitDetails.TotalBackoffTime)) + parts = append(parts, fmt.Sprintf("total_backoff_time:%vs", commitDetails.TotalBackoffTime.Seconds())) } resolveLockTime := atomic.LoadInt64(&commitDetails.ResolveLockTime) if resolveLockTime > 0 { - parts = append(parts, fmt.Sprintf("resolve_lock_time:%d", time.Duration(resolveLockTime))) + parts = append(parts, fmt.Sprintf("resolve_lock_time:%vs", time.Duration(resolveLockTime).Seconds())) } if commitDetails.LocalLatchTime > 0 { - parts = append(parts, fmt.Sprintf("local_latch_wait_time:%v", commitDetails.LocalLatchTime)) + parts = append(parts, fmt.Sprintf("local_latch_wait_time:%vs", commitDetails.LocalLatchTime.Seconds())) } if commitDetails.WriteKeys > 0 { parts = append(parts, fmt.Sprintf("write_keys:%d", commitDetails.WriteKeys)) diff --git a/util/execdetails/execdetails_test.go b/util/execdetails/execdetails_test.go index b69f2229c7668..cd69856070bf8 100644 --- a/util/execdetails/execdetails_test.go +++ b/util/execdetails/execdetails_test.go @@ -20,14 +20,26 @@ import ( func TestString(t *testing.T) { detail := &ExecDetails{ - ProcessTime: time.Second, + ProcessTime: 2*time.Second + 5*time.Millisecond, WaitTime: time.Second, BackoffTime: time.Second, RequestCount: 1, TotalKeys: 100, ProcessedKeys: 10, + CommitDetail: &CommitDetails{ + GetCommitTsTime: time.Second, + PrewriteTime: time.Second, + CommitTime: time.Second, + LocalLatchTime: time.Second, + TotalBackoffTime: time.Second, + ResolveLockTime: 1000000000, // 10^9 ns = 1s + WriteKeys: 1, + WriteSize: 1, + PrewriteRegionNum: 1, + TxnRetry: 1, + }, } - expected := "process_time:1s wait_time:1s backoff_time:1s request_count:1 total_keys:100 processed_keys:10" + expected := "process_time:2.005s wait_time:1s backoff_time:1s request_count:1 total_keys:100 processed_keys:10 prewrite_time:1s commit_time:1s get_commit_ts_time:1s total_backoff_time:1s resolve_lock_time:1s local_latch_wait_time:1s write_keys:1 write_size:1 prewrite_region:1 txn_retry:1" if str := detail.String(); str != expected { t.Errorf("got:\n%s\nexpected:\n%s", str, expected) } diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 0fac60de09a33..02ef84a95d89f 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -86,10 +86,10 @@ type RecordSet interface { Fields() []*ast.ResultField // Next reads records into chunk. - Next(ctx context.Context, chk *chunk.Chunk) error + Next(ctx context.Context, req *chunk.RecordBatch) error - // NewChunk creates a new chunk with initial capacity. - NewChunk() *chunk.Chunk + //NewRecordBatch create a recordBatch. + NewRecordBatch() *chunk.RecordBatch // Close closes the underlying iterator, call Next after Close will // restart the iteration.