load data: sync rewrite set expr for concurrent load (#43492)
ref #42930
D3Hunter authored May 9, 2023
1 parent 09bf23f commit 5f11392
Showing 5 changed files with 87 additions and 34 deletions.
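
Background for the change: each concurrent LOAD DATA encoder used to walk the shared SET-clause AST for every row via expression.EvalAstExpr, and the rewrite mutates AST nodes in place through Accept, so parallel encoders raced on the same nodes. This commit compiles the expressions once per worker under a new LoadDataController.colAssignMu mutex. Below is a minimal, self-contained Go sketch of that pattern; the types (controller, fakeAssignment) are invented for illustration and are not TiDB APIs:

```go
package main

import (
	"fmt"
	"sync"
)

// fakeAssignment stands in for *ast.Assignment: rewriting it mutates the node in place.
type fakeAssignment struct {
	expr     string
	rewrites int // written by every rewrite, like the in-place writes done via Accept
}

// controller mimics the role of LoadDataController for this sketch.
type controller struct {
	mu      sync.Mutex // counterpart of the new colAssignMu field
	assigns []*fakeAssignment
}

// createColAssignExprs compiles the shared assignments once per caller,
// serializing the in-place AST walk the way CreateColAssignExprs does.
func (c *controller) createColAssignExprs() []string {
	c.mu.Lock()
	defer c.mu.Unlock()
	out := make([]string, 0, len(c.assigns))
	for _, a := range c.assigns {
		a.rewrites++ // the shared-state write that races without the lock
		out = append(out, "compiled("+a.expr+")")
	}
	return out
}

func main() {
	c := &controller{assigns: []*fakeAssignment{{expr: "a=@V1"}, {expr: "b=@V2*10"}}}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ { // several encode workers starting up concurrently
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = c.createColAssignExprs()
		}()
	}
	wg.Wait()
	fmt.Println(c.assigns[0].rewrites) // 4, and `go run -race` stays quiet
}
```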
27 changes: 27 additions & 0 deletions executor/importer/import.go
@@ -22,6 +22,7 @@ import (
 	"path/filepath"
 	"runtime"
 	"strings"
+	"sync"
 	"unicode/utf8"
 
 	"github.com/pingcap/errors"
@@ -32,6 +33,7 @@ import (
 	"github.com/pingcap/tidb/br/pkg/lightning/mydump"
 	"github.com/pingcap/tidb/br/pkg/storage"
 	"github.com/pingcap/tidb/executor/asyncloaddata"
+	"github.com/pingcap/tidb/expression"
 	tidbkv "github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/parser"
 	"github.com/pingcap/tidb/parser/ast"
@@ -192,6 +194,9 @@ type LoadDataController struct {
 	*Plan
 	*ASTArgs
 
+	// used to synchronize column assignment expression generation.
+	colAssignMu sync.Mutex
+
 	Table table.Table
 
 	// how input field(or input column) from data file is mapped, either to a column or variable.
@@ -940,6 +945,28 @@ func (e *LoadDataController) toMyDumpFiles() []mydump.FileInfo {
 	return res
 }
 
+// CreateColAssignExprs creates the column assignment expressions using the session context.
+// RewriteAstExpr writes the AST node in place (through xxNode.Accept) without changing the
+// node's content, so serializing callers with colAssignMu is enough to make this safe.
+func (e *LoadDataController) CreateColAssignExprs(sctx sessionctx.Context) ([]expression.Expression, []stmtctx.SQLWarn, error) {
+	e.colAssignMu.Lock()
+	defer e.colAssignMu.Unlock()
+	res := make([]expression.Expression, 0, len(e.ColumnAssignments))
+	allWarnings := []stmtctx.SQLWarn{}
+	for _, assign := range e.ColumnAssignments {
+		newExpr, err := expression.RewriteAstExpr(sctx, assign.Expr, nil, nil, false)
+		// Column assignment expression warnings are static; they must be reported once per
+		// processed row later, so we save them and clear the statement context here.
+		allWarnings = append(allWarnings, sctx.GetSessionVars().StmtCtx.GetWarnings()...)
+		sctx.GetSessionVars().StmtCtx.SetWarnings(nil)
+		if err != nil {
+			return nil, nil, err
+		}
+		res = append(res, newExpr)
+	}
+	return res, allWarnings, nil
+}
+
 // JobImportParam is the param of the job import.
 type JobImportParam struct {
 	Job *asyncloaddata.Job
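
CreateColAssignExprs above detaches compile-time warnings from the shared statement context so callers can replay them per row. Here is a self-contained sketch of that save-and-clear pattern; stmtCtx and the warning text are invented for illustration:

```go
package main

import "fmt"

// stmtCtx stands in for the session's statement context warning buffer.
type stmtCtx struct{ warns []string }

// compileAssigns mirrors the loop in CreateColAssignExprs: each rewrite may
// push warnings into the shared context; we move them into `saved` and clear
// the context so they are not double-reported by the statement itself.
func compileAssigns(assigns []string, sc *stmtCtx) (compiled, saved []string) {
	for _, a := range assigns {
		compiled = append(compiled, "expr("+a+")")
		if a == "b=@V2*10" {
			sc.warns = append(sc.warns, "implicit cast in "+a) // a rewrite-time warning
		}
		saved = append(saved, sc.warns...) // keep the warnings alongside the expressions
		sc.warns = nil                     // ...and clear the shared context
	}
	return compiled, saved
}

func main() {
	sc := &stmtCtx{}
	_, saved := compileAssigns([]string{"a=@V1", "b=@V2*10"}, sc)
	fmt.Println(saved, sc.warns) // [implicit cast in b=@V2*10] []
}
```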
22 changes: 12 additions & 10 deletions executor/importer/kv_encode.go
@@ -30,6 +30,7 @@ import (
 	"github.com/pingcap/tidb/sessionctx/variable"
 	"github.com/pingcap/tidb/table"
 	"github.com/pingcap/tidb/types"
+	"github.com/pingcap/tidb/util/chunk"
 )
 
 type kvEncoder interface {
@@ -44,7 +45,7 @@ type kvEncoder interface {
 type tableKVEncoder struct {
 	*kv.BaseKVEncoder
 	// see import.go
-	columnAssignments  []*ast.Assignment
+	columnAssignments  []expression.Expression
 	columnsAndUserVars []*ast.ColumnNameOrUserVar
 	fieldMappings      []*FieldMapping
 	insertColumns      []*table.Column
@@ -54,24 +55,25 @@ var _ kvEncoder = &tableKVEncoder{}
 
 func newTableKVEncoder(
 	config *encode.EncodingConfig,
-	columnAssignments []*ast.Assignment,
-	columnsAndUserVars []*ast.ColumnNameOrUserVar,
-	fieldMappings []*FieldMapping,
-	insertColumns []*table.Column,
+	ti *TableImporter,
 ) (*tableKVEncoder, error) {
 	baseKVEncoder, err := kv.NewBaseKVEncoder(config)
 	if err != nil {
 		return nil, err
 	}
 	// we need a non-nil TxnCtx to avoid panic when evaluating set clause
 	baseKVEncoder.SessionCtx.Vars.TxnCtx = new(variable.TransactionContext)
+	colAssignExprs, _, err := ti.CreateColAssignExprs(baseKVEncoder.SessionCtx)
+	if err != nil {
+		return nil, err
+	}
 
 	return &tableKVEncoder{
 		BaseKVEncoder:      baseKVEncoder,
-		columnAssignments:  columnAssignments,
-		columnsAndUserVars: columnsAndUserVars,
-		fieldMappings:      fieldMappings,
-		insertColumns:      insertColumns,
+		columnAssignments:  colAssignExprs,
+		columnsAndUserVars: ti.ColumnsAndUserVars,
+		fieldMappings:      ti.FieldMappings,
+		insertColumns:      ti.InsertColumns,
 	}, nil
 }
 
@@ -131,7 +133,7 @@ func (en *tableKVEncoder) parserData2TableData(parserData []types.Datum, rowID i
 	}
 	for i := 0; i < len(en.columnAssignments); i++ {
 		// eval expression of `SET` clause
-		d, err := expression.EvalAstExpr(en.SessionCtx, en.columnAssignments[i].Expr)
+		d, err := en.columnAssignments[i].Eval(chunk.Row{})
 		if err != nil {
 			return nil, err
 		}
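
The encoder change above swaps a per-row AST walk (expression.EvalAstExpr) for evaluating expressions compiled once at construction; since the rewrite binds no table columns (nil schema and names are passed to RewriteAstExpr), evaluating against an empty chunk.Row appears to suffice. A rough sketch of that compile-once, evaluate-per-row shape — expr and the helpers are invented stand-ins, not TiDB types:

```go
package main

import "fmt"

// expr stands in for expression.Expression: already compiled, cheap to evaluate.
type expr func() (int64, error)

// compileOnce plays the role of CreateColAssignExprs: the expensive, mutating
// rewrite happens once per encoder instead of once per row.
func compileOnce() []expr {
	return []expr{
		func() (int64, error) { return 123, nil }, // e.g. SET c = 123
	}
}

// encodeRow is the per-row hot path: no AST in sight, just Eval calls.
func encodeRow(compiled []expr, row []int64) ([]int64, error) {
	for _, e := range compiled {
		v, err := e() // counterpart of en.columnAssignments[i].Eval(chunk.Row{})
		if err != nil {
			return nil, err
		}
		row = append(row, v)
	}
	return row, nil
}

func main() {
	compiled := compileOnce()
	out, _ := encodeRow(compiled, []int64{1, 11})
	fmt.Println(out) // [1 11 123]
}
```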
2 changes: 1 addition & 1 deletion executor/importer/table_import.go
@@ -269,7 +269,7 @@ func (ti *TableImporter) getKVEncoder(chunk *checkpoints.ChunkCheckpoint) (kvEnc
 		Table:  ti.encTable,
 		Logger: log.Logger{Logger: ti.logger.With(zap.String("path", chunk.FileMeta.Path))},
 	}
-	return newTableKVEncoder(cfg, ti.ColumnAssignments, ti.ColumnsAndUserVars, ti.FieldMappings, ti.InsertColumns)
+	return newTableKVEncoder(cfg, ti)
 }
 
 func (ti *TableImporter) importTable(ctx context.Context) (err error) {
34 changes: 26 additions & 8 deletions executor/load_data.go
@@ -422,10 +422,16 @@ func (ji *logicalJobImporter) initEncodeCommitWorkers(e *LoadDataWorker) (err er
 			return err2
 		}
 		createdSessions = append(createdSessions, commitCore.ctx)
+		colAssignExprs, exprWarnings, err2 := e.controller.CreateColAssignExprs(encodeCore.ctx)
+		if err2 != nil {
+			return err2
+		}
 		encode := &encodeWorker{
-			InsertValues: encodeCore,
-			controller:   e.controller,
-			killed:       &e.UserSctx.GetSessionVars().Killed,
+			InsertValues:   encodeCore,
+			controller:     e.controller,
+			colAssignExprs: colAssignExprs,
+			exprWarnings:   exprWarnings,
+			killed:         &e.UserSctx.GetSessionVars().Killed,
 		}
 		encode.resetBatch()
 		encodeWorkers = append(encodeWorkers, encode)
@@ -583,6 +589,11 @@ func (ji *logicalJobImporter) Result() importer.JobImportResult {
 		numSkipped += commitStmtCtx.RecordRows() - commitStmtCtx.CopiedRows()
 	}
 
+	// Column assignment expression warnings are generated during init and are static;
+	// we need to report them once for each processed row.
+	colAssignExprWarnings := ji.encodeWorkers[0].exprWarnings
+	numWarnings += numRecords * uint64(len(colAssignExprWarnings))
+
 	if numWarnings > math.MaxUint16 {
 		numWarnings = math.MaxUint16
 	}
@@ -596,6 +607,9 @@ func (ji *logicalJobImporter) Result() importer.JobImportResult {
 	for _, w := range ji.commitWorkers {
 		n += copy(warns[n:], w.ctx.GetSessionVars().StmtCtx.GetWarnings())
 	}
+	for i := 0; i < int(numRecords) && n < len(warns); i++ {
+		n += copy(warns[n:], colAssignExprWarnings)
+	}
 	return importer.JobImportResult{
 		Msg:          msg,
 		LastInsertID: ji.getLastInsertID(),
@@ -640,9 +654,13 @@ func (ji *logicalJobImporter) Close() error {
 // encodeWorker is a sub-worker of LoadDataWorker that is dedicated to encoding data.
 type encodeWorker struct {
 	*InsertValues
-	controller *importer.LoadDataController
-	killed     *uint32
-	rows       [][]types.Datum
+	controller     *importer.LoadDataController
+	colAssignExprs []expression.Expression
+	// sessionCtx generates warnings when rewriting an AST node into an expression;
+	// we should report such warnings for each encoded row.
+	exprWarnings []stmtctx.SQLWarn
+	killed       *uint32
+	rows         [][]types.Datum
 }
 
 // processStream always tries to build a parser from the channel and process it. When
@@ -837,9 +855,9 @@ func (w *encodeWorker) parserData2TableData(
 
 		row = append(row, parserData[i])
 	}
-	for i := 0; i < len(w.controller.ColumnAssignments); i++ {
+	for i := 0; i < len(w.colAssignExprs); i++ {
 		// eval expression of `SET` clause
-		d, err := expression.EvalAstExpr(w.ctx, w.controller.ColumnAssignments[i].Expr)
+		d, err := w.colAssignExprs[i].Eval(chunk.Row{})
 		if err != nil {
 			if w.controller.Restrictive {
 				return nil, err
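
The Result changes count the static SET-clause warnings once per record, then saturate the total at the uint16 limit the client protocol can carry. A small sketch of that arithmetic, with function and parameter names of my own choosing rather than TiDB's:

```go
package main

import (
	"fmt"
	"math"
)

// totalWarnings mirrors the accounting in Result: dynamic warnings collected
// from the workers, plus the static per-row warnings multiplied out, clamped
// to the uint16 count reported to the client.
func totalWarnings(numRecords, staticPerRecord, dynamic uint64) uint64 {
	n := dynamic + numRecords*staticPerRecord
	if n > math.MaxUint16 {
		n = math.MaxUint16 // the reported count saturates at 65535
	}
	return n
}

func main() {
	fmt.Println(totalWarnings(9, 1, 2))       // 11
	fmt.Println(totalWarnings(1000000, 1, 0)) // 65535
}
```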
36 changes: 21 additions & 15 deletions tests/realtikvtest/loaddatatest/load_data_test.go
@@ -829,28 +829,34 @@ func (s *mockGCSSuite) TestColumnsAndUserVars() {
 }
 
 func (s *mockGCSSuite) testColumnsAndUserVars(importMode string, distributed bool) {
-	withOptions := fmt.Sprintf("WITH thread=1, import_mode='%s'", importMode)
+	withOptions := fmt.Sprintf("WITH thread=2, import_mode='%s'", importMode)
 	withOptions = adjustOptions(withOptions, distributed)
 	s.prepareVariables(distributed)
 	s.tk.MustExec("DROP DATABASE IF EXISTS load_data;")
 	s.tk.MustExec("CREATE DATABASE load_data;")
-	s.tk.MustExec(`CREATE TABLE load_data.cols_and_vars (a INT, b INT);`)
+	s.tk.MustExec(`CREATE TABLE load_data.cols_and_vars (a INT, b INT, c int);`)
 
 	s.server.CreateObject(fakestorage.Object{
-		ObjectAttrs: fakestorage.ObjectAttrs{
-			BucketName: "test-load",
-			Name:       "cols_and_vars.tsv",
-		},
-		Content: []byte("1\n2\n3\n4\n5\n"),
+		ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-load", Name: "cols_and_vars-1.tsv"},
+		Content:     []byte("1,11,111\n2,22,222\n3,33,333\n4,44,444\n5,55,555\n"),
 	})
+	s.server.CreateObject(fakestorage.Object{
+		ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-load", Name: "cols_and_vars-2.tsv"},
+		Content:     []byte("6,66,666\n7,77,777\n8,88,888\n9,99,999\n"),
+	})
-	sql := fmt.Sprintf(`LOAD DATA INFILE 'gs://test-load/cols_and_vars.tsv?endpoint=%s'
-		INTO TABLE load_data.cols_and_vars(@V1) set a=@V1, b=@V1*100 %s`, gcsEndpoint, withOptions)
+	sql := fmt.Sprintf(`LOAD DATA INFILE 'gs://test-load/cols_and_vars-*.tsv?endpoint=%s'
+		INTO TABLE load_data.cols_and_vars fields terminated by ','
+		(@V1, @v2, @v3) set a=@V1, b=@V2*10, c=123 %s`, gcsEndpoint, withOptions)
 	s.tk.MustExec(sql)
-	s.tk.MustQuery("SELECT * FROM load_data.cols_and_vars;").Check(testkit.Rows(
-		"1 100",
-		"2 200",
-		"3 300",
-		"4 400",
-		"5 500",
+	s.tk.MustQuery("SELECT * FROM load_data.cols_and_vars;").Sort().Check(testkit.Rows(
+		"1 110 123",
+		"2 220 123",
+		"3 330 123",
+		"4 440 123",
+		"5 550 123",
+		"6 660 123",
+		"7 770 123",
+		"8 880 123",
+		"9 990 123",
+		"9 990 123",
	))
}
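
Note on the test change: the case now loads two source files with thread=2 (previously one file with thread=1), so two encode workers share the same LoadDataController and exercise the synchronized rewrite path concurrently; Sort() is added because rows loaded from the files in parallel can arrive in any order.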
