Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: fix tidb_decode_key with partition table #39312

Merged
merged 9 commits into from
Nov 22, 2022
24 changes: 24 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3050,6 +3050,30 @@ func TestTiDBDecodeKeyFunc(t *testing.T) {
sql = fmt.Sprintf("select tidb_decode_key( '%s' )", hexKey)
rs = fmt.Sprintf(`{"%s":%d,"table_id":"%d"}`, tbl.Meta().GetPkName().String(), rowID, tbl.Meta().ID)
tk.MustQuery(sql).Check(testkit.Rows(rs))

// Test partition table.
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t (a int primary key clustered, b int, key bk (b)) PARTITION BY RANGE (a) (PARTITION p0 VALUES LESS THAN (1), PARTITION p1 VALUES LESS THAN (2));")
dom = domain.GetDomain(tk.Session())
is = dom.InfoSchema()
tbl, err = is.TableByName(model.NewCIStr("test"), model.NewCIStr("t"))
require.NoError(t, err)
require.NotNil(t, tbl.Meta().Partition)
hexKey = buildTableRowKey(tbl.Meta().Partition.Definitions[0].ID, rowID)
sql = fmt.Sprintf("select tidb_decode_key( '%s' )", hexKey)
rs = fmt.Sprintf(`{"%s":%d,"partition_id":%d,"table_id":"%d"}`, tbl.Meta().GetPkName().String(), rowID, tbl.Meta().Partition.Definitions[0].ID, tbl.Meta().ID)
tk.MustQuery(sql).Check(testkit.Rows(rs))

hexKey = tablecodec.EncodeTablePrefix(tbl.Meta().Partition.Definitions[0].ID).String()
sql = fmt.Sprintf("select tidb_decode_key( '%s' )", hexKey)
rs = fmt.Sprintf(`{"partition_id":%d,"table_id":%d}`, tbl.Meta().Partition.Definitions[0].ID, tbl.Meta().ID)
tk.MustQuery(sql).Check(testkit.Rows(rs))

data = []types.Datum{types.NewIntDatum(100)}
hexKey = buildIndexKeyFromData(tbl.Meta().Partition.Definitions[0].ID, tbl.Indices()[0].Meta().ID, data)
sql = fmt.Sprintf("select tidb_decode_key( '%s' )", hexKey)
rs = fmt.Sprintf(`{"index_id":1,"index_vals":{"b":"100"},"partition_id":%d,"table_id":%d}`, tbl.Meta().Partition.Definitions[0].ID, tbl.Meta().ID)
tk.MustQuery(sql).Check(testkit.Rows(rs))
}

func TestTwoDecimalTruncate(t *testing.T) {
Expand Down
26 changes: 23 additions & 3 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2169,6 +2169,9 @@ func decodeKeyFromString(ctx sessionctx.Context, s string) string {
return s
}
tbl, _ := is.TableByID(tableID)
if tbl == nil {
tbl, _, _ = is.FindTableByPartitionID(tableID)
}
loc := ctx.GetSessionVars().Location()
if tablecodec.IsRecordKey(key) {
ret, err := decodeRecordKey(key, tableID, tbl, loc)
Expand All @@ -2185,7 +2188,7 @@ func decodeKeyFromString(ctx sessionctx.Context, s string) string {
}
return ret
} else if tablecodec.IsTableKey(key) {
ret, err := decodeTableKey(key, tableID)
ret, err := decodeTableKey(key, tableID, tbl)
if err != nil {
sc.AppendWarning(err)
return s
Expand All @@ -2203,6 +2206,10 @@ func decodeRecordKey(key []byte, tableID int64, tbl table.Table, loc *time.Locat
}
if handle.IsInt() {
ret := make(map[string]interface{})
if tbl != nil && tbl.Meta().Partition != nil {
ret["partition_id"] = tableID
tableID = tbl.Meta().ID
}
ret["table_id"] = strconv.FormatInt(tableID, 10)
// When the clustered index is enabled, we should show the PK name.
if tbl != nil && tbl.Meta().HasClusteredIndex() {
Expand Down Expand Up @@ -2239,6 +2246,10 @@ func decodeRecordKey(key []byte, tableID int64, tbl table.Table, loc *time.Locat
return "", errors.Trace(err)
}
ret := make(map[string]interface{})
if tbl.Meta().Partition != nil {
ret["partition_id"] = tableID
tableID = tbl.Meta().ID
}
ret["table_id"] = tableID
handleRet := make(map[string]interface{})
for colID := range datumMap {
Expand Down Expand Up @@ -2308,6 +2319,10 @@ func decodeIndexKey(key []byte, tableID int64, tbl table.Table, loc *time.Locati
ds = append(ds, d)
}
ret := make(map[string]interface{})
if tbl.Meta().Partition != nil {
ret["partition_id"] = tableID
tableID = tbl.Meta().ID
}
ret["table_id"] = tableID
ret["index_id"] = indexID
idxValMap := make(map[string]interface{}, len(targetIndex.Columns))
Expand Down Expand Up @@ -2340,8 +2355,13 @@ func decodeIndexKey(key []byte, tableID int64, tbl table.Table, loc *time.Locati
return string(retStr), nil
}

func decodeTableKey(_ []byte, tableID int64) (string, error) {
ret := map[string]int64{"table_id": tableID}
func decodeTableKey(_ []byte, tableID int64, tbl table.Table) (string, error) {
ret := map[string]int64{}
if tbl != nil && tbl.Meta().GetPartitionInfo() != nil {
ret["partition_id"] = tableID
tableID = tbl.Meta().ID
}
ret["table_id"] = tableID
retStr, err := json.Marshal(ret)
if err != nil {
return "", errors.Trace(err)
Expand Down