Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: Fix issue 39999, used wrong column id list for checking partitions #40003

Merged
merged 9 commits into from
Dec 20, 2022
23 changes: 8 additions & 15 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3511,13 +3511,10 @@ func buildIndexRangeForEachPartition(ctx sessionctx.Context, usedPartitions []ta
return nextRange, nil
}

func keyColumnsIncludeAllPartitionColumns(keyColumns []int, pe *tables.PartitionExpr) bool {
tmp := make(map[int]struct{}, len(keyColumns))
for _, offset := range keyColumns {
tmp[offset] = struct{}{}
}
for _, offset := range pe.ColumnOffset {
if _, ok := tmp[offset]; !ok {
func keyColumnsIncludeAllPartitionColumns(keyColumnIDs []int64, pt table.PartitionedTable) bool {
partColIDs := pt.GetPartitionColumnIDs()
for _, id := range keyColumnIDs {
if _, ok := partColIDs[id]; !ok {
return false
}
}
Expand Down Expand Up @@ -4149,12 +4146,6 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte
}
tbl, _ := builder.is.TableByID(tbInfo.ID)
pt := tbl.(table.PartitionedTable)
pe, err := tbl.(interface {
bb7133 marked this conversation as resolved.
Show resolved Hide resolved
PartitionExpr() (*tables.PartitionExpr, error)
}).PartitionExpr()
if err != nil {
return nil, err
}
partitionInfo := &v.PartitionInfo
usedPartitionList, err := builder.partitionPruning(pt, partitionInfo.PruningConds, partitionInfo.PartitionNames, partitionInfo.Columns, partitionInfo.ColumnNames)
if err != nil {
Expand All @@ -4166,13 +4157,14 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte
}
var kvRanges []kv.KeyRange
if v.IsCommonHandle {
if len(lookUpContents) > 0 && keyColumnsIncludeAllPartitionColumns(lookUpContents[0].keyCols, pe) {
if len(lookUpContents) > 0 && keyColumnsIncludeAllPartitionColumns(lookUpContents[0].keyColIDs, pt) {
locateKey := make([]types.Datum, e.Schema().Len())
kvRanges = make([]kv.KeyRange, 0, len(lookUpContents))
// lookUpContentsByPID groups lookUpContents by pid(partition) so that kv ranges for same partition can be merged.
lookUpContentsByPID := make(map[int64][]*indexJoinLookUpContent)
for _, content := range lookUpContents {
for i, date := range content.keys {
// TODO: Add a test where partition column is not a prefix or out of order with source joined table
locateKey[content.keyCols[i]] = date
}
p, err := pt.GetPartitionByRow(e.ctx, locateKey)
Expand Down Expand Up @@ -4212,11 +4204,12 @@ func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Conte

handles, lookUpContents := dedupHandles(lookUpContents)

if len(lookUpContents) > 0 && keyColumnsIncludeAllPartitionColumns(lookUpContents[0].keyCols, pe) {
if len(lookUpContents) > 0 && keyColumnsIncludeAllPartitionColumns(lookUpContents[0].keyColIDs, pt) {
locateKey := make([]types.Datum, e.Schema().Len())
kvRanges = make([]kv.KeyRange, 0, len(lookUpContents))
for _, content := range lookUpContents {
for i, date := range content.keys {
// TODO: Add test to see if keyColIDs should be used instead?
locateKey[content.keyCols[i]] = date
}
p, err := pt.GetPartitionByRow(e.ctx, locateKey)
Expand Down
56 changes: 56 additions & 0 deletions executor/partition_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3831,3 +3831,59 @@ func TestIssue21732(t *testing.T) {
})
}
}

func TestIssue39999(t *testing.T) {
store := testkit.CreateMockStore(t)

tk := testkit.NewTestKit(t, store)

tk.MustExec(`create schema test39999`)
tk.MustExec(`use test39999`)
tk.MustExec(`drop table if exists c, t`)
tk.MustExec("CREATE TABLE `c` (" +
"`serial_id` varchar(24)," +
"`txt_account_id` varchar(24)," +
"`capital_sub_class` varchar(10)," +
"`occur_trade_date` date," +
"`occur_amount` decimal(16,2)," +
"`broker` varchar(10)," +
"PRIMARY KEY (`txt_account_id`,`occur_trade_date`,`serial_id`) /*T![clustered_index] CLUSTERED */," +
"KEY `idx_serial_id` (`serial_id`)" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci " +
"PARTITION BY RANGE COLUMNS(`serial_id`) (" +
"PARTITION `p202209` VALUES LESS THAN ('20221001')," +
"PARTITION `p202210` VALUES LESS THAN ('20221101')," +
"PARTITION `p202211` VALUES LESS THAN ('20221201')" +
")")

tk.MustExec("CREATE TABLE `t` ( " +
"`txn_account_id` varchar(24), " +
"`account_id` varchar(32), " +
"`broker` varchar(10), " +
"PRIMARY KEY (`txn_account_id`) /*T![clustered_index] CLUSTERED */ " +
") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci")

tk.MustExec("INSERT INTO `c` VALUES ('2022111700196920','04482786','CUST','2022-11-17',-2.01,'0009')")
tk.MustExec("INSERT INTO `t` VALUES ('04482786','1142927','0009')")

tk.MustExec(`set tidb_partition_prune_mode='dynamic'`)
tk.MustExec(`analyze table c`)
tk.MustExec(`analyze table t`)
query := `select
/*+ inl_join(c) */
c.occur_amount
from
c
join t on c.txt_account_id = t.txn_account_id
and t.broker = '0009'
and c.occur_trade_date = '2022-11-17'`
tk.MustQuery("explain " + query).Check(testkit.Rows(""+
"IndexJoin_22 1.00 root inner join, inner:TableReader_21, outer key:test39999.t.txn_account_id, inner key:test39999.c.txt_account_id, equal cond:eq(test39999.t.txn_account_id, test39999.c.txt_account_id)",
"├─TableReader_27(Build) 1.00 root data:Selection_26",
"│ └─Selection_26 1.00 cop[tikv] eq(test39999.t.broker, \"0009\")",
"│ └─TableFullScan_25 1.00 cop[tikv] table:t keep order:false",
"└─TableReader_21(Probe) 1.00 root partition:all data:Selection_20",
" └─Selection_20 1.00 cop[tikv] eq(test39999.c.occur_trade_date, 2022-11-17 00:00:00.000000)",
" └─TableRangeScan_19 1.00 cop[tikv] table:c range: decided by [eq(test39999.c.txt_account_id, test39999.t.txn_account_id) eq(test39999.c.occur_trade_date, 2022-11-17 00:00:00.000000)], keep order:false"))
tk.MustQuery(query).Check(testkit.Rows("-2.01"))
}
1 change: 1 addition & 0 deletions table/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ type PartitionedTable interface {
GetPartitionByRow(sessionctx.Context, []types.Datum) (PhysicalTable, error)
GetAllPartitionIDs() []int64
GetPartitionColumnNames() []model.CIStr
GetPartitionColumnIDs() map[int64]struct{}
CheckForExchangePartition(ctx sessionctx.Context, pi *model.PartitionInfo, r []types.Datum, pid int64) error
}

Expand Down
23 changes: 23 additions & 0 deletions table/tables/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,29 @@ func (t *partitionedTable) GetPartitionColumnNames() []model.CIStr {
return colNames
}

// GetPartitionColumnIDs returns the column IDs from the partition expression
// TODO: refactor and have the column ids or a hash on PartitionInfo instead
func (t *partitionedTable) GetPartitionColumnIDs() map[int64]struct{} {
var ids map[int64]struct{}
meta := t.Meta()
pi := meta.Partition
if len(pi.Columns) > 0 {
ids = make(map[int64]struct{}, len(pi.Columns))
for _, name := range pi.Columns {
col := table.FindColLowerCase(t.Cols(), name.L)
ids[col.ID] = struct{}{}
}
return ids
}

partitionCols := expression.ExtractColumns(t.partitionExpr.Expr)
ids = make(map[int64]struct{}, len(partitionCols))
for _, col := range partitionCols {
ids[col.ID] = struct{}{}
}
return ids
}

// PartitionRecordKey is exported for test.
func PartitionRecordKey(pid int64, handle int64) kv.Key {
recordPrefix := tablecodec.GenTableRecordPrefix(pid)
Expand Down